PyTorch Flatten 操作

class Flatten(nn.Module):
    """Flatten every dimension after the first (batch) dimension.

    An input of shape ``(N, d1, d2, ..., dk)`` becomes
    ``(N, d1 * d2 * ... * dk)``. This matches ``torch.nn.Flatten()``
    with its default ``start_dim=1``.
    """

    def __init__(self):
        super().__init__()

    def forward(self, input):
        """Return ``input`` flattened from dim 1 onward.

        Args:
            input: Tensor with at least 2 dimensions; dim 0 is kept as-is.

        Returns:
            Tensor of shape ``(input.shape[0], prod(input.shape[1:]))``.
        """
        # torch.flatten keeps dim 0 explicitly instead of inferring it via -1,
        # so empty batches and zero-size feature dims both work. Unlike
        # .view(), it also accepts non-contiguous inputs (copying if needed).
        return torch.flatten(input, start_dim=1)