基于pytorch神经网络模型参数的加载及自定义

最近在训练MobileNet时经常会对其模型参数进行各种操作,或者替换其中的几层之类的,故总结一下用到的对神经网络参数的各种操作方法。

1.将matlab的.mat格式参数整理转换为tensor类型的模型参数

import torch
import torch.nn as nn
import torch.nn.functional as F
import scipy.io as scio
import os
import numpy as np
from config import Config
import json
config = Config()

Mul = Config.MUL.astype(\'float32\')
Shift = Config.SHIFT.astype(\'float32\')

def load_json(j_fn):
    with open(j_fn,\'r\') as f:
        data = json.load(f)
    return data

def save_json(dic,j_fn):
    json_str = json.dumps(dic)
    with open(j_fn,\'w\') as json_file:
        json_file.write(json_str)

w_dic = {}
b_dic = {}
dic_all = {}
for i in range(1,28,2):
    a = \'w\'+str(i)    #按顺序命名
    b = \'b\'+str(i)
    dic_all[a] = torch.from_numpy(scio.loadmat(config.WEIGHT_PATH + str(i)+\'.mat\')[\'wei\'] * Mul[i-1]/(2**Shift[i-1])).permute(3, 2, 0, 1)
    dic_all[b] = torch.squeeze(torch.from_numpy(scio.loadmat(config.BIAS_PATH + str(i)+\'.mat\')[\'bias\'] * Mul[i-1]/(2**Shift[i-1])))
    # print(a, \'Mul\'+str(i-1))
    if i == 27:
        break
    a = \'w\'+str(i+1)
    b = \'b\'+str(i+1)
    dic_all[a] = torch.from_numpy(scio.loadmat(config.WEIGHT_PATH + str(i+1)+\'.mat\')[\'wei\'] * Mul[i]/(2**Shift[i])).permute(2, 0, 1).unsqueeze(1)
    dic_all[b] = torch.squeeze(torch.from_numpy(scio.loadmat(config.BIAS_PATH + str(i+1)+\'.mat\')[\'bias\'] * Mul[i]/(2**Shift[i])))
#此处由于自己之前的命名问题,中间跳过了28层(池化层),直接按照有参数的层存储了参数,故27后的文件名变成了29
dic_all[\'w29\'] = torch.squeeze(torch.from_numpy(scio.loadmat(config.WEIGHT_PATH + \'29.mat\')[\'wei\'] * Mul[28]/(2**Shift[28])).permute(3, 2, 0, 1)[1:, :])
dic_all[\'b29\'] = torch.squeeze(torch.from_numpy(scio.loadmat(config.BIAS_PATH + \'29.mat\')[\'bias\'] * Mul[28]/(2**Shift[28])))[1:]
#存为.pth文件
param_fn = \'mobilenet_param_float.pth\'
torch.save(dic_all,param_fn)

其中,mul和shift为量化后的乘子和移位参数(如果参数是浮点的则可以忽略这部分),另外,我的weight和bias是按照每层单独存在一个按照层序号命名的.mat文件中。且由于是从matlab的程序得到的,需要对参数的维度进行一下转换(permute()方法),同时对需要增加或减少维度的用unsqueeze()或torch.squeeze()方法进行改变(注意一定要和网络需要的输入维度相同才行)。最后按照原来对参数文件命名的方式保存成字典存成.pth文件(此时的字典还不能直接使用,需要在具体定义的网络中更换想应的key值)。

*另外,代码中用来读取和存储.json文件的函数可以忽略,在这里没有用到

2.将自定义网络的参数替换成自己需要的(DIY模型参数)

from mobilenet_v1 import MobileNet_v1
import torch
from config import Config
from load_data import loadtestdata
from torch.autograd import Variable
import numpy as np
from Mobilenetv1_quantified import MobileNet, MobileNet_Bayer
import json
import matplotlib.pyplot as plt
import numpy as np
import torchvision

param_keys = [\'w1\', \'b1\', \'w2\', \'b2\', \'w3\', \'b3\', \'w4\', \'b4\', \'w5\', \'b5\', \'w6\', \'b6\', \'w7\', \'b7\', \'w8\', \'b8\', \'w9\', \'b9\', \'w10\', \'b10\', \'w11\', \'b11\', \'w12\', \'b12\', \'w13\', \'b13\', \'w14\', \'b14\', \'w15\', \'b15\', \'w16\', \'b16\', \'w17\', \'b17\', \'w18\', \'b18\', \'w19\', \'b19\', \'w20\', \'b20\', \'w21\', \'b21\', \'w22\', \'b22\', \'w23\', \'b23\', \'w24\', \'b24\', \'w25\', \'b25\', \'w26\', \'b26\', \'w27\', \'b27\', \'w29\', \'b29\']
file_name = \'/home/wangshuyu/MobileNet_v1/mobilenet_param_float.pth\'
dic_param = torch.load(file_name)      # 此处打开上一步存成的参数字典(按照每一层的权重、偏置的顺序)
Model = MobileNet()                    # 实例化预定义的MobileNet网络(网络结构将在其他文中给出)
net_dic = Model.state_dict()           # 加载预定义网络的参数字典,用来获取网络的键值
for i, param_tensor in enumerate(net_dic ,0):
    net_dic[param_tensor] = dic_param[param_keys[i]]
    # print(i,\'\t\',param_tensor ,net_dic[param_tensor].shape)   #可以用来查看参数的维度
param_fn = \'MobileNet_float.pth\'
torch.save(net_dic,param_fn)
# 下面开始是自己定义的另一个网络,只需要固定MobileNet其中一部分参数,剩下的部分参数用来训练,因此只从第11个之后的开始取参数
model2 = MobileNet_Bayer()
dic2 = model2.state_dict()
key_list = list(net_dic.keys())
for i, param_tensor in enumerate(dic2 ,0):
    if i > 11:
        dic2[param_tensor] = (net_dic[key_list[i - 2]])
    print(i, \'\t\', param_tensor, dic2[param_tensor].shape)
param_fn2 = \'MobileNet_Bayer.pth\'
torch.save(dic2,param_fn2)

这里主要实现了将之前存好的量化后的mobilenet每层参数根据自己定义的网络构建了参数字典,在训练或测试的时候,只需要加载之前存好的预训练参数就可以了:

from Mobilenetv1_quantified import MobileNet
import torch
from load_data import loadtestdata Net = MobileNet() param_dic = torch.load(\'MobileNet_float.pth\') Net.load_state_dict(param_dic)
classes = range(0,1000)
test_data = loadtestdate()
test(test_data, Net, classes)