[Pytorch]Pytorch加载预训练模型,转

转自:https://blog.csdn.net/Vivianyzw/article/details/81061765

东风的地方

1. 直接加载预训练模型

在训练的时候可能需要中断一下,然后继续训练,也就是简单的从保存的模型中加载参数权重:

  1. net = SNet()

  2. net.load_state_dict(torch.load("model_1599.pkl"))

这种方式是针对于之前保存模型时以保存参数的格式使用的:

torch.save(net.state_dict(), "model/model_1599.pkl")

pytorch官网更推荐上述模型保存方法,也据说这种方式比下一种更快一点。

下面介绍第二种模型保存和加载的方式:

  1. net = SNet()

  2. torch.save(net, "model_1599.pkl")

  3. snet = torch.load("model_1599.pkl")

这种方式会将整个网络保存下来,数据量会更大,会消耗更多的时间,占用内存也更高。

2. 加载一部分预训练模型

模型可能是一些经典的模型改掉一部分,比如一般算法中提取特征的网络常见的会直接使用vgg16的features extraction部分,也就是在训练的时候可以直接加载已经在imagenet上训练好的预训练参数,这种方式实现如下:

  1. net = SNet()

  2. model_dict = net.state_dict()

  3. vgg16 = models.vgg16(pretrained=True)

  4. pretrained_dict = vgg16.state_dict()

  5. pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}

  6. model_dict.update(pretrained_dict)

  7. net.load_state_dict(model_dict)

也就是在网络中state_dict部分,属于vgg16的,替换成vgg16预训练模型里的参数(代码里的k:v for k,v in pretrained_dict.items() if k in model_dict),其他保持不变。

3. 微调经典网络

因为pytorch中的torchvision给出了很多经典常用模型,并附加了预训练模型。利用好这些训练好的基础网络可以加快不少自己的训练速度。

首先比如加载vgg16(带有预训练参数的形式):

  1. import torchvision.models as models

  2. vgg16 = models.vgg16(pretrained=True)

比如,网络第一层本来是Conv2d(3, 64, 3, 1, 1),想修改成Conv2d(4, 64, 3, 1 ,1),那直接赋值就可以了:

  1. import torch.nn as nn

  2. vgg16.features[0]=nn.Conv2d(4, 64, 3, 1, 1)

4. 修改经典网络

这个比上面微调修改的地方要多一些,但是想介绍一下这样的修改方式。

先简单介绍一下我需要需改的部分,在vgg16的基础模型下,每一个卷积都要加一个dropout层,并将ReLU激活函数换成PReLU,最后两层的Pooling层stride改成1。直接上代码:

  1. def feature_layer():

  2. layers = []

  3. pool1 = ['4', '9', '16']

  4. pool2 = ['23', '30']

  5. vgg16 = models.vgg16(pretrained=True).features

  6. for name, layer in vgg16._modules.items():

  7. if isinstance(layer, nn.Conv2d):

  8. layers += [layer, nn.Dropout2d(0.5), nn.PReLU()]

  9. elif name in pool1:

  10. layers += [layer]

  11. elif name == pool2[0]:

  12. layers += [nn.MaxPool2d(2, 1, 1)]

  13. elif name == pool2[1]:

  14. layers += [nn.MaxPool2d(2, 1, 0)]

  15. else:

  16. continue

  17. features = nn.Sequential(*layers)

  18. #feat3 = features[0:24]

  19. return features

大概的思路就是,创建一个新的网络(layers列表), 遍历vgg16里每一层,如果遇到卷积层(if isinstance(layer, nn.Conv2d)就先把该层(Conv2d)保持原样加进去,随后增加一个dropout层,再加一个PReLU层。然后如果遇到最后两层pool,就修改响应参数加进去,其他的pool正常加载。 最后将这个layers列表转成网络的nn.Sequential的形式,最后返回features。然后再你的新的网络层就可以用以下方式来加载:

  1. class SNet(nn.Module):

  2. def __init__(self):

  3. super(SNet, self).__init__()

  4. self.features = feature_layer()

  5. def forward(self, x):

  6. x = self.features(x)

  7. return x