PyTorch 介绍 | 保存和加载模型

本节我们将会看到如何保存模型状态、加载和运行模型预测

import torch
import torchvision.models as models

保存和加载模型权重

PyTorch模型在一个称为 state_dict 的内部状态字典内保存了学习的参数,可以通过 torch.save实现这一过程。

model = models.vgg16(pretrained=True)
torch.save(model.state_dict(), 'model_weights.pth')

为了加载模型参数,你需要首先创建一个相同模型的实体,然后使用 load_state_dict()加载参数。

model = models.vgg16() # we do not specify pretrained=True, i.e. do not load default weights
model.load_state_dict(torch.load('model_weights.pth'))
model.eval()

注意:在推理前,确保调用 model.eval() 设置dropout和batch normalization层是评估模式,否则将产生不一致的推断结果。

使用Shapes保存和加载模型

当加载模型权重时,我们需要首先初始化模型类,因为该类定义了网络结构。我们可能想将模型权重和该类的结构保存在一起,在这种情况下,可以将 model (而不是model.state_dict())传入保存函数。

torch.save(model, 'model.pth')

加载

model = torch.load('model.pth')

注意:这种方法在序列化模型时使用Python pickle模块,因此,它依赖于加载模型时可用的实际类的定义。

相关教程

Saving and Loading a General Checkpoint in PyTorch