pytorch搭建网络,保存参数,恢复参数

这是看过莫凡python的学习笔记。

搭建网络,两种方式

(1)建立Sequential对象

import torch
net = torch.nn.Sequential(
            torch.nn.Linear(2,10),
            torch.nn.ReLU(),
            torch.nn.Linear(10,2))

输出网络结构

Sequential(
  (0): Linear(in_features=2, out_features=10, bias=True)
  (1): ReLU()
  (2): Linear(in_features=10, out_features=2, bias=True)
)

(2)建立网络类,继承torch.nn.module

class Net(torch.nn.Module):
    def __init__(self):
        super(Net,self).__init__()
        self.hidden = torch.nn.Linear(2,10)
        self.predict = torch.nn.Linear(10,2)
    def forward(self,x):
        x = F.relu(self.hidden(x))
        x = self.predict(x)
        return x

输出和上面基本一样,略微不同

Net(
  (hidden): Linear(in_features=2, out_features=10, bias=True)
  (predict): Linear(in_features=10, out_features=2, bias=True)
)

保存模型,两种方式

(1)保存整个网络,及网络参数

torch.save(net,'net.pkl')

(2)只保存网络参数

torch.save(net.state_dict(),'net_params.pkl')

恢复模型,两种方式

(1)加载整个网络,及参数

net2 = torch.load('net.pkl')

(2)加载参数,但需实现网络

net3 = torch.nn.Sequential(
            torch.nn.Linear(2,10),
            torch.nn.ReLU(),
            torch.nn.Linear(10,2))
net3.load_state_dict(torch.load('net_params.pkl'))