Pytorch 模型的存储与加载

本文主要内容来自Pytorch官方文档推荐的一篇英文博客, 本文主要介绍了在Pytorch中模型的存储方法, 以及存储形式, 以及Pytorch存储模型正真存储的是模型的什么结构. 以及加载模型的时候, 模型的哪些数据会被加载. 以及加载后的形式.

首先大致讲下三个最主要的函数的功能:

torch.save: 将序列化的对象存储到硬盘中.此函数使用Python的pickle实用程序进行序列化. 对于数据类型都可以进行序列化存储, 模型, 张量, 以及字典, 等各种数据对象都可以使用该函数存储.

torch.load: 该函数使用的是 pickle 的阶序列化过程, 并将结果存如内存中, 该函数也促进设备加载数据.

torch.nn.Module.load_state_dict: 使用反序列化的 state_dict 加载模型的参数字典

模型的加载

state_dict 是什么

在一个Pytorch模型中, 通常是 torch.nn.module , 模型中可学习的参数被包含在模型的参数中, 通常是可以使用 model.parameters() 函数访问, 通常都是使用该方法访问的. state_dict只是一个Python字典对象,它将每个图层映射到其参数张量, 这个字典的 key 是图层的 'name', 注意, 只有该层有可学习的参数的层, 也就是可以通过反向传播优化的层, 以及 registered buffers (batchnorm’s running_mean) 才会在 state_dict 中有存储条目. 优化器对象(torch.optim)也具有state_dict,其中包含有关优化器状态以及所用超参数的信息. state_dict 的本质是对模型进行了字典化.

state_dict的字典形式使得对模型的操作更加的灵活, 例如直接导出模型, 修改其中的参数信息, 或者对层数进行修改等, 然后继续将模型保留. 还是使用一个简单的模型举个例子:

class TheModelClass(nn.Module):
    def __init__(self):
        super(TheModelClass, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        # 这里卷积核的大小是 5, 个数是 6, 输入的 width 是 3
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # 两次卷积的结果应该是 5x5x16 的矩阵
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)
        # 可以看出网络层的结构, 两个卷积层, 其余还有全连接层

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Initialize model
model = TheModelClass()

# Initialize optimizer
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

可以得到模型的输出为:

Model's state_dict:
conv1.weight     torch.Size([6, 3, 5, 5])
conv1.bias   torch.Size([6])
conv2.weight     torch.Size([16, 6, 5, 5])
conv2.bias   torch.Size([16])
fc1.weight   torch.Size([120, 400])
fc1.bias     torch.Size([120])
fc2.weight   torch.Size([84, 120])
fc2.bias     torch.Size([84])
fc3.weight   torch.Size([10, 84])
fc3.bias     torch.Size([10])

Optimizer's state_dict:
state    {}
param_groups     [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [4675713712, 4675713784, 4675714000, 4675714072, 4675714216, 4675714288, 4675714432, 4675714504, 4675714648, 4675714720]}]

模型的参数的输出是字典的键值对, 后面是优化参数的输出, 也是键值对

存储与加载模型对应的形式

使用 state_dict 存储与加载模型

save:

torch.save(model.state_dict(), PATH)

Load 模型:

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

从模型存储的角度, 存储模型的时候, 唯一需要存储的是该模型训练的参数, torch.save() 函数也可以存储模型的 state_dict. 使用该方法进行存储, 模型被看做字典形式, 所以对模型的操作更加灵活. 在这种形式下常见的PyTorch约定是使用.pt或.pth文件扩展名保存模型.

注意, 加载模型之后, 并不能直接运行, 需要使用 model.eval() 函数设置 Dropout 与层间正则化. 另一方面, 该方法在存储模型的时候是以字典的形式存储的, 也就是存储的是模型的字典数据, Pytorch 不能直接将模型读取为该形式, 必须先 torch.load() 该模型, 然后再使用 load_state_dict().

将模型作为整体存储与加载

Save:

torch.save(model, PATH)

Load:

# Model class must be defined somewhere
model = torch.load(PATH)
model.eval()

使用该方法相当于跳过了对模型的 state_dict 描述的过程, 而是直接使用 python 的 pickle 包, 这种方法的缺点是, 模型的存储形式与加载形式十分固定, 这样做的原因是因为pickle不会保存模型类本身. 而是存出来包含该文件的路径,该路径在加载时使用. 因此,在其他项目中使用或重构后,代码可能会以各种方式中断. 但是这种方法存储的文件的类型与前面的方法一样. 同样, 以该方法加载模型运行之前需要调用 model.eval() .

存储与加载一般的 Checkpoint

Save:

torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            ...
            }, PATH)

Load:

model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

model.eval()
# - or -
model.train()

可以看出 checkpoints 是模型主要内容的一个字典, 基本包含了模型各种数据, 例如上面的例子模型的参数使用的是 optimizer.state_dict().

存储 checkpoints 主要目的是为了方便加载模型继续训练, 将所有的信息存储, 加载模型继续训练的时候就会更加方便. 为了存储一个训练过程的多种信息, 最好的方式是使用 dictionary 进行序列化, 这样存储一个训练模型的形式是 .tar, 要加载项目,首先初始化模型和优化器,然后使用torch.load() 在本地加载字典.从这里开始, 只需按期望查询字典即可轻松访问已保存的项目. 请记住,在运行推理之前,必须调用model.eval() 来将 Dropout 和 Batch 正则化设置为评估模式, 不这样做将产生不一致的推断结果. 如果恢复训练,那么调用model.train() 以确保这些层处于训练模式.

在一个文件中存储多个模型

save:

torch.save({
            'modelA_state_dict': modelA.state_dict(),
            'modelB_state_dict': modelB.state_dict(),
            'optimizerA_state_dict': optimizerA.state_dict(),
            'optimizerB_state_dict': optimizerB.state_dict(),
            ...
            }, PATH)

Load:

modelA = TheModelAClass(*args, **kwargs)
modelB = TheModelBClass(*args, **kwargs)
optimizerA = TheOptimizerAClass(*args, **kwargs)
optimizerB = TheOptimizerBClass(*args, **kwargs)

checkpoint = torch.load(PATH)
modelA.load_state_dict(checkpoint['modelA_state_dict'])
modelB.load_state_dict(checkpoint['modelB_state_dict'])
optimizerA.load_state_dict(checkpoint['optimizerA_state_dict'])
optimizerB.load_state_dict(checkpoint['optimizerB_state_dict'])

modelA.eval()
modelB.eval()
# - or -
modelA.train()
modelB.train()

保存包含多个 torch.nn.Modules 的模型(例如GAN,序列到序列模型或模型集合)时,将采用与保存常规检查点相同的方法。 换句话说,保存每个模型的state_dict和相应的优化器的字典. 如前所述,您可以保存任何其他可以帮助您恢复培训的项目,只需将它们添加到字典中即可. 使用该方法存储的文件也是 .tar 形式的, 要加载模型,请首先初始化模型和优化器,然后使用torch.load()在本地加载字典。 从这里,您只需按期望查询字典即可轻松访问已保存的项目.

跨平台模型保存与加载

GPU 到 CPU

Save:

torch.save(model.state_dict(), PATH)

Load:

device = torch.device('cpu')
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location=device))
Save on GPU, Load on GPU

Save:

torch.save(model.state_dict(), PATH)

Load:

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model
Save on CPU, Load on GPU

Save:

torch.save(model.state_dict(), PATH)

Load:

device = torch.device("cuda")
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH, map_location="cuda:0"))  # Choose whatever GPU device number you want
model.to(device)
# Make sure to call input = input.to(device) on any input tensors that you feed to the model

加载部分模型

举个例子:

# build
        encoder = TransformerModel(params, dico, is_encoder=True, with_output=False)  # TODO: only output when necessary - len(params.clm_steps + params.mlm_steps) > 0
        decoder = TransformerModel(params, dico, is_encoder=False, with_output=True)
        
# reload pretrained word embeddings
        if params.reload_emb != '':
                # 表示加载预训练模型
            word2id, embeddings = load_embeddings(params.reload_emb, params)
            set_pretrain_emb(encoder, dico, word2id, embeddings)
            set_pretrain_emb(decoder, dico, word2id, embeddings)
            set_pretrain_emb(model2, dico, word2id, embeddings)

        # reload a pretrained model
        if params.reload_model != '':
            enc_path, dec_path = params.reload_model.split(',')
            assert not (enc_path == '' and dec_path == '')

            # reload encoder
            if enc_path != '':
                enc_reload = torch.load(enc_path, map_location=lambda storage, loc: storage.cuda(params.local_rank))
                # 预训练模型是在 GPU 上训练的
                enc_reload = enc_reload['model' if 'model' in enc_reload else 'encoder']
                # 导入存储的文件的模型
                if all([k.startswith('module.') for k in enc_reload.keys()]):
                    enc_reload = {k[len('module.'):]: v for k, v in enc_reload.items()}
                # 这个过程相当于将model 反序列化为 state_dict的形式
                encoder.load_state_dict(enc_reload, strict=False)
                # 这个后面的 strict=False 就是对 encoder 与 enc_reload.state_dict之间差异进行处理, 如果encoder 的模型结构与 enc_reload模型结构
                # 不一样的时候, 就会向 encoder 转化, 也就是 encoder 不包含的层就不会导入, 例如这里 enc_reload 就是一个完整的 Transformer 模型, 但是
                # encoder 是不包含输出部分的, 所以就不会加载这部分

对于该部分, 本文只是做了个简单的例子介绍, 更详细的内容参见传送门 . 对于这个传送门的例子, 如果我们先存储一个大模型, 将大模型加载到小模型的时候, 使用:

path = 'xxx.pth'
model = Net()
model.load_state_dict(t.load(path), strict=False)
for module in model.named_modules():
    print(module)
for name, param in model.named_parameters():
    print(name, param)

从输出可以看出模型向下兼容,