Pytorch-修改预训练参数

我自己改进的模型为model(model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)),原模型为resnet50。

1.查看模型参数

现模型:

1 model_dict = model.state_dict()
2 for k,v in model_dict.items():
3     print(k)

预训练模型参数

1 pretrained_dict = model_zoo.load_url(model_urls['resnet50'])
2 for k,v in pretrained_dict.items():
3     print(k)

2.将预训练参数赋给自己改进的模型

改进的模型参数和原模型参数一致时:

1 import torch.utils.model_zoo as model_zoo
2 
3 model_urls = {
4     'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth'
5 }
6 
7 model.load_state_dict(model_zoo.load_url(model_urls['resnet50']), strict=False)   

Tip:如果两个模型参数完全一致的话,strict=True,如果两个模型参数不一致的话,当strict=False预训练模型会把具有相同参数名称的值赋给改进的参数,不相同的则不赋值。

改进的模型参数和原模型参数不一致时,使用部分预训练模型参数初始化网络 :

1 model_dict = model.state_dict()          #取出自己模型的网络参数 
2 pretrained_dict = model_zoo.load_url(model_urls['resnet50'])
3 
4 model_dict['classifiers.3.fc.weight'] = pretrained_dict['fc.weight'][:2]
5 model_dict['classifiers.3.fc.bias'] = pretrained_dict['fc.bias'][:2]