PyTorch 介绍 | 保存和加载模型
本节我们将会看到如何保存模型状态、加载和运行模型预测
import torch
import torchvision.models as models
保存和加载模型权重
PyTorch模型在一个称为 state_dict
的内部状态字典内保存了学习的参数,可以通过 torch.save
实现这一过程。
model = models.vgg16(pretrained=True)
torch.save(model.state_dict(), 'model_weights.pth')
为了加载模型参数,你需要首先创建一个相同模型的实体,然后使用 load_state_dict()
加载参数。
model = models.vgg16() # we do not specify pretrained=True, i.e. do not load default weights
model.load_state_dict(torch.load('model_weights.pth'))
model.eval()
注意:在推理前,确保调用 model.eval()
设置dropout和batch normalization层是评估模式,否则将产生不一致的推断结果。
使用Shapes保存和加载模型
当加载模型权重时,我们需要首先初始化模型类,因为该类定义了网络结构。我们可能想将模型权重和该类的结构保存在一起,在这种情况下,可以将 model
(而不是model.state_dict()
)传入保存函数。
torch.save(model, 'model.pth')
加载
model = torch.load('model.pth')
注意:这种方法在序列化模型时使用Python pickle模块,因此,它依赖于加载模型时可用的实际类的定义。