PyTorch 实战,模型训练、模型加载、模型测试

本次将一个使用Pytorch的一个实战项目,记录流程:自定义数据集->数据加载->搭建神经网络->迁移学习->保存模型->加载模型->测试模型

自定义数据集

参考我的上一篇博客:自定义数据集处理

数据加载

默认小伙伴有对深度学习框架有一定的了解,这里就不做过多的说明了。

好吧,还是简单的说一下吧:

我们在做好了自定义数据集之后,其实数据的加载和MNSIT 、CIFAR-10 、CIFAR-100等数据集的都是相似的,过程如下所示:

导入必要的包

import torch

from torch import optim, nn

import visdom

from torch.utils.data import DataLoader

1

2

3

4

加载数据

可以发现和MNIST 、CIFAR的加载基本上是一样的

train_db = Pokemon('pokeman', 224, mode='train')

val_db = Pokemon('pokeman', 224, mode='val')

test_db = Pokemon('pokeman', 224, mode='test')

train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True,

num_workers=4)

val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=2)

test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=2)

1

2

3

4

5

6

7

搭建神经网络

ResNet-18网络结构:

在这里插入图片描述

ResNet全名Residual Network残差网络。Kaiming He 的《Deep Residual Learning for Image Recognition》获得了CVPR最佳论文。他提出的深度残差网络在2015年可以说是洗刷了图像方面的各大比赛,以绝对优势取得了多个比赛的冠军。而且它在保证网络精度的前提下,将网络的深度达到了152层,后来又进一步加到1000的深度。论文的开篇先是说明了深度网络的好处:特征等级随着网络的加深而变高,网络的表达能力也会大大提高。因此论文中提出了一个问题:是否可以通过叠加网络层数来获得一个更好的网络呢?作者经过实验发现,单纯的把网络叠起来的深层网络的效果反而不如合适层数的较浅的网络效果。因此何恺明等人在普通平原网络的基础上增加了一个shortcut, 构成一个residual block。此时拟合目标就变为F(x),F(x)就是残差:

在这里插入图片描述

训练模型

def evalute(model, loader):

model.eval()

correct = 0

total = len(loader.dataset)

for x, y in loader:

x, y = x.to(device), y.to(device)

with torch.no_grad():

logits = model(x)

pred = logits.argmax(dim=1)

correct += torch.eq(pred, y).sum().float().item()

return correct / total

def main():

model = ResNet18(5).to(device)

optimizer = optim.Adam(model.parameters(), lr=lr)

criteon = nn.CrossEntropyLoss()

best_acc, best_epoch = 0, 0

global_step = 0

viz.line([0], [-1], win='loss', opts=dict(title='loss'))

viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))

for epoch in range(epochs):

for step, (x, y) in enumerate(train_loader):

x, y = x.to(device), y.to(device)

model.train()

logits = model(x)

loss = criteon(logits, y)

optimizer.zero_grad()

loss.backward()

optimizer.step()

viz.line([loss.item()], [global_step], win='loss', update='append')

global_step += 1

if epoch % 1 == 0:

val_acc = evalute(model, val_loader)

if val_acc > best_acc:

best_epoch = epoch

best_acc = val_acc

viz.line([val_acc], [global_step], win='val_acc', update='append')

print('best acc:', best_acc, 'best epoch:', best_epoch)

model.load_state_dict(torch.load('best.mdl'))

print('loaded from ckpt!')

test_acc = evalute(model, test_loader)

迁移学习

提升模型的准确率:

# model = ResNet18(5).to(device)

trained_model=resnet18(pretrained=True) # 此时是一个非常好的model

model = nn.Sequential(*list(trained_model.children())[:-1], # 此时使用的是前17层的网络 0-17 *:随机打散

Flatten(),

nn.Linear(512,5)

).to(device)

# x=torch.randn(2,3,224,224)

# print(model(x).shape)

optimizer = optim.Adam(model.parameters(), lr=lr)

criteon = nn.CrossEntropyLoss()

1

2

3

4

5

6

7

8

9

10

11

保存、加载模型

pytorch保存模型的方式有两种:

第一种:将整个网络都都保存下来

第二种:仅保存和加载模型参数(推荐使用这样的方法)

# 保存和加载整个模型

torch.save(model_object, 'model.pkl')

model = torch.load('model.pkl')

1

2

3

# 仅保存和加载模型参数(推荐使用)

torch.save(model_object.state_dict(), 'params.pkl')

model_object.load_state_dict(torch.load('params.pkl'))

1

2

3

可以看到这是我保存的模型:

其中best.mdl是第二中方法保存的

model.pkl则是第一种方法保存的

在这里插入图片描述

测试模型

这里是训练时的情况

在这里插入图片描述

看这个数据准确率还是不错的,但是还是需要实际的测试这个模型,看它到底学到东西了没有,接下来简单的测试一下:

import torch

from PIL import Image

from torchvision import transforms

device = torch.device('cuda')

transform=transforms.Compose([

transforms.Resize(256),

transforms.CenterCrop(224),

transforms.ToTensor(),

transforms.Normalize(mean=[0.485,0.456,0.406],

std=[0.229,0.224,0.225])

])

def prediect(img_path):

net=torch.load('model.pkl')

net=net.to(device)

torch.no_grad()

img=Image.open(img_path)

img=transform(img).unsqueeze(0)

img_ = img.to(device)

outputs = net(img_)

_, predicted = torch.max(outputs, 1)

# print(predicted)

print('this picture maybe :',classes[predicted[0]])

if __name__ == '__main__':

prediect('./test/name.jpg')

实际的测试结果:

在这里插入图片描述

在这里插入图片描述

效果还是可以的,完整的代码:

https://github.com/huzixuan1/Loader_DateSet

数据集下载链接:

https://pan.baidu.com/s/12-NQiF4fXEOKrXXdbdDPCg