1 import torch
2 from torch import optim,nn
3 import visdom
4 import torchvision
5 from torch.utils.data import DataLoader
6
7 from pokemon import Pokemon
8
9 # from resnet import ResNet18
# torchvision's resnet18 can load already-trained (pretrained) weights directly.
11 from torchvision.models import resnet18
12
13 from utils import Flatten
14
# Training hyperparameters.
batchsz = 32
lr = 1e-3
epochs = 10

# Prefer the GPU, but fall back to the CPU so the script still runs on
# machines without CUDA instead of failing on the first `.to(device)`.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Fix the random seed so that runs are reproducible.
torch.manual_seed(1234)

# Train / validation / test splits of the dataset rooted at 'pokemon'.
# NOTE(review): 224 is presumably the image resize target -- confirm against
# pokemon.Pokemon.
train_db = Pokemon('pokemon', 224, mode='train')
val_db = Pokemon('pokemon', 224, mode='val')
test_db = Pokemon('pokemon', 224, mode='test')

# Only the training loader shuffles; evaluation order does not matter.
train_loader = DataLoader(train_db, batch_size=batchsz, shuffle=True, num_workers=4)
val_loader = DataLoader(val_db, batch_size=batchsz, num_workers=2)
test_loader = DataLoader(test_db, batch_size=batchsz, num_workers=2)

# Visdom client used by main() to plot the loss and validation-accuracy curves.
viz = visdom.Visdom()
33
def evalute(model, loader):
    """Return the classification accuracy of ``model`` over ``loader``.

    The model is switched to evaluation mode while scoring (so layers such as
    BatchNorm/Dropout behave deterministically and running statistics are not
    updated) and its previous train/eval mode is restored before returning.

    Args:
        model: An ``nn.Module`` mapping a batch ``x`` to per-class logits.
        loader: An iterable of ``(x, y)`` batches exposing a ``dataset``
            attribute whose length is the total number of samples.

    Returns:
        The fraction of samples whose arg-max prediction equals the label,
        as a float in ``[0, 1]``; ``0.0`` when the dataset is empty.
    """
    total = len(loader.dataset)
    if total == 0:
        # Nothing to score; avoid a ZeroDivisionError below.
        return 0.0

    # Move batches to wherever the model lives; fall back to the module-level
    # ``device`` for models that have no parameters.
    first_param = next(model.parameters(), None)
    target = first_param.device if first_param is not None else device

    was_training = model.training
    model.eval()
    correct = 0
    try:
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(target), y.to(target)
                pred = model(x).argmax(dim=1)
                correct += torch.eq(pred, y).sum().item()
    finally:
        # Leave the model in the mode the caller had it in.
        model.train(was_training)

    return correct / total
47
def main():
    """Fine-tune a pretrained ResNet-18 as a 5-class classifier and test it.

    Trains for ``epochs`` epochs on ``train_loader``, validates on
    ``val_loader`` every second epoch, checkpoints the best-scoring weights
    to ``best.mdl``, then reloads that checkpoint and reports accuracy on
    ``test_loader``. Loss and validation accuracy are streamed to visdom.
    """
    trained_model = resnet18(pretrained=True)
    # Keep every child module except the final fully-connected layer; the
    # ``*`` unpacks the list into Sequential's positional arguments.
    model = nn.Sequential(
        *list(trained_model.children())[:-1],  # [b,512,1,1]
        Flatten(),                             # [b,512,1,1] --> [b,512]
        nn.Linear(512, 5),
    ).to(device)

    optimizer = optim.Adam(model.parameters(), lr=lr)
    # CrossEntropyLoss must be instantiated; it takes raw logits and integer
    # class labels, so no manual one-hot encoding is needed.
    criteon = nn.CrossEntropyLoss()

    best_acc, best_epoch = 0, 0
    has_checkpoint = False
    global_step = 0
    # Create the visdom plot windows with a placeholder first point.
    viz.line([0], [-1], win='loss', opts=dict(title='loss'))
    viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))

    for epoch in range(epochs):
        # Make sure training-time behaviour (BatchNorm/Dropout) is active,
        # regardless of what evaluation did to the module's mode.
        model.train()

        for x, y in train_loader:
            # x: [b,3,224,224], y: [b]
            x, y = x.to(device), y.to(device)

            # Raw, un-normalised class scores.
            logits = model(x)
            # Keep the loss as a tensor: backward() needs the autograd graph.
            loss = criteon(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            viz.line([loss.item()], [global_step], win='loss', update='append')
            global_step += 1

        if epoch % 2 == 0:
            val_acc = evalute(model, val_loader)

            # Always write a checkpoint on the first validation so that the
            # reload below cannot fail when accuracy never exceeds 0.
            if val_acc > best_acc or not has_checkpoint:
                best_epoch = epoch
                best_acc = val_acc
                torch.save(model.state_dict(), 'best.mdl')
                has_checkpoint = True

            viz.line([val_acc], [global_step], win='val_acc', update='append')

    print('best_acc:', best_acc, 'best_epoch', best_epoch)

    if has_checkpoint:
        # map_location keeps the load working when the checkpoint was written
        # on a different device than the current one.
        model.load_state_dict(torch.load('best.mdl', map_location=device))
        print('loaded from skpt!')

    test_acc = evalute(model, test_loader)
    print('test_acc', test_acc)
106
107
# Run training only when executed as a script, not when imported.
if __name__ == '__main__':
    main()