pytorch 中改变tensor维度,transpose、拼接

具体示例如下,注意观察维度的变化

1.改变tensor维度的操作:transpose、view、permute、t()、expand、repeat

#coding=utf-8
import  torch

def change_tensor_shape():
    x=torch.randn(2,4,3)
    s=x.transpose(1,2) #shape=[2,3,4]
    y=x.view(2,3,4) #shape=[2,3,4]
    z=x.permute(0,2,1) #shape=[2,3,4]

    #tensor.t()只能转化 a 2D tensor
    m=torch.randn(2,3)#shape=[2,3]
    n=m.t()#shape=[3,2]
    print(m)
    print(n)

    #返回当前张量在某个维度为1扩展为更大的张量
    x = torch.Tensor([[1], [2], [3]])#shape=[3,1]
    t=x.expand(3, 4)
    print(t)
    '''
    tensor([[1., 1., 1., 1.],
        [2., 2., 2., 2.],
        [3., 3., 3., 3.]])
    '''

    #沿着特定的维度重复这个张量
    x=torch.Tensor([[1,2,3]])
    t=x.repeat(3, 2)
    print(t)
    '''
    tensor([[1., 2., 3., 1., 2., 3.],
        [1., 2., 3., 1., 2., 3.],
        [1., 2., 3., 1., 2., 3.]])
    '''
    x = torch.randn(2, 3, 4)
    t=x.repeat(2, 1, 3) #shape=[4, 3, 12]

if __name__=='__main__':
    change_tensor_shape()

2.tensor的拼接:cat、stack

除了要拼接的维度可以不相等,其他维度必须相等

#coding=utf-8
import  torch


def cat_and_stack():

    x = torch.randn(2,3,6)
    y = torch.randn(2,4,6)
    c=torch.cat((x,y),1)
    #c=(2*7*6)
    print(c.size)

    """
    而stack则会增加新的维度。
    如对两个1*2维的tensor在第0个维度上stack,则会变为2*1*2的tensor;在第1个维度上stack,则会变为1*2*2的tensor。
    """
    a = torch.rand((1, 2))
    b = torch.rand((1, 2))
    c = torch.stack((a, b), 0)
    print(c.size())

if __name__=='__main__':
    cat_and_stack()

3.压缩和扩展维度:改变tensor中只有1个维度的tensor

torch.squeeze(input, dim=None, out=None) → Tensor

除去输入张量input中数值为1的维度,并返回新的张量。如果输入张量的形状为(A×1×B×C×1×D) 那么输出张量的形状为(A×B×C×D)

当通过dim参数指定维度时,维度压缩操作只会在指定的维度上进行。如果输入向量的形状为(A×1×B),

squeeze(input, 0)会保持张量的维度不变,只有在执行squeeze(input, 1)时,输入张量的形状会被压缩至(A×B) 。

如果一个张量只有1个维度,那么它不会受到上述方法的影响。

#coding=utf-8
import  torch


def squeeze_tensor():
    x = torch.Tensor(1,3)
    y=torch.squeeze(x, 0)
    print("y:",y)
    y=torch.unsqueeze(y, 1)
    print("y:",y)

if __name__=='__main__':
    squeeze_tensor()