from sys import gettrace
from typing import ForwardRef
import torch
from torch.nn.modules import flatten
import torchvision
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l
#来自于李沐深度学习课程https://zh-v2.d2l.ai/chapter_preface/index.html
#类别预测层
def cls_predictor(num_inputs, num_anchors, num_classes):
    """Class prediction layer.

    A 3x3 conv whose output has ``num_anchors * (num_classes + 1)``
    channels: one score per class (plus background) for every anchor
    at every spatial position.
    """
    out_channels = num_anchors * (num_classes + 1)
    return nn.Conv2d(num_inputs, out_channels, kernel_size=3, padding=1)
#边界框预测层
def bbox_predictor(num_inputs, num_anchors):
    """Bounding-box offset prediction layer: 4 offsets per anchor per position."""
    return nn.Conv2d(num_inputs, num_anchors * 4,
                     kernel_size=3, padding=1)
#连接多尺度预测层
def forward(x, block):
    """Run a single block's forward pass on input ``x`` and return its output."""
    result = block(x)
    return result
# Y1 = forward(torch.zeros((2,8,20,20)),cls_predictor(8,5,10))
def flatten_pred(pred):
    """Move channels last, then flatten each example to a 1-D vector.

    Channels-last ordering keeps all predictions for one spatial
    position contiguous, so multi-scale outputs can be concatenated.
    """
    channels_last = pred.permute(0, 2, 3, 1)
    return channels_last.reshape(channels_last.shape[0], -1)
def concat_preds(preds):
    """Flatten every per-scale prediction map and concatenate along dim 1."""
    flat_maps = [flatten_pred(p) for p in preds]
    return torch.cat(flat_maps, dim=1)
#高宽减半
def down_sample_blk(in_channels, out_channels):
    """Halve height and width: two conv-BN-ReLU stages, then 2x2 max pooling."""
    layers = []
    channels = in_channels
    for _ in range(2):
        layers.extend([
            nn.Conv2d(channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        ])
        # After the first stage, subsequent convs map out_channels -> out_channels.
        channels = out_channels
    layers.append(nn.MaxPool2d(2))
    return nn.Sequential(*layers)
#基本网络块,用于从输入图像中抽取特征,网络块输出的特征图为:32*32(256/(2*2*2)=32)
def base_net():
    """Base feature extractor: three halving blocks (3->16->32->64 channels).

    With a 256x256 input the output feature map is 32x32 (256 / 2^3).
    """
    channel_plan = [3, 16, 32, 64]
    stages = [down_sample_blk(c_in, c_out)
              for c_in, c_out in zip(channel_plan[:-1], channel_plan[1:])]
    return nn.Sequential(*stages)
def get_blk(i):
    """Return stage ``i`` of TinySSD.

    Stage 0 is the base network, stage 4 collapses the map to 1x1 with
    global average pooling, and every other stage halves the resolution.
    """
    if i == 0:
        return base_net()
    if i == 1:
        return down_sample_blk(64, 128)
    if i == 4:
        return nn.AdaptiveAvgPool2d((1, 1))
    return down_sample_blk(128, 128)
#为每一个块定义前向计算, 输出包括:特征图、生成的锚框,预测的锚框的类别和偏移量
def blk_forward(X, blk, size, ratio, cls_predictor, bbox_predictor):
    """Forward computation for one detection stage.

    Returns the stage's feature map, the anchors generated on it, and
    the class and bounding-box offset predictions for those anchors.
    """
    feature_map = blk(X)
    generated_anchors = d2l.multibox_prior(feature_map, sizes=size, ratios=ratio)
    return (feature_map,
            generated_anchors,
            cls_predictor(feature_map),
            bbox_predictor(feature_map))
#定义完整的模型
class TinySSD(nn.Module):
    """Tiny single-shot detector.

    Five stages, each producing a feature map, anchors, class scores and
    bounding-box offsets; the per-stage outputs are concatenated.

    NOTE(review): relies on module-level globals ``num_anchors``, ``sizes``
    and ``ratios`` (one size/ratio list per stage), as in the d2l recipe —
    they must be defined before construction.
    """

    def __init__(self, num_classes, **kwargs):
        super(TinySSD, self).__init__(**kwargs)
        self.num_classes = num_classes
        # Input channels of each stage's prediction heads.
        idx_to_in_channels = [64, 128, 128, 128, 128]
        for i in range(5):
            # Register per-stage sub-modules as attributes blk_i / cls_i / bbox_i.
            setattr(self, f'blk_{i}', get_blk(i))
            setattr(self, f'cls_{i}', cls_predictor(idx_to_in_channels[i],
                                                    num_anchors, num_classes))
            # BUG FIX: original registered the literal name 'bbox_[i]', so the
            # getattr(self, f'bbox_{i}') lookup in forward() raised AttributeError.
            setattr(self, f'bbox_{i}', bbox_predictor(idx_to_in_channels[i],
                                                      num_anchors))

    def forward(self, X):
        anchors, cls_preds, bbox_preds = [None] * 5, [None] * 5, [None] * 5
        for i in range(5):
            X, anchors[i], cls_preds[i], bbox_preds[i] = blk_forward(
                X, getattr(self, f'blk_{i}'), sizes[i], ratios[i],
                getattr(self, f'cls_{i}'), getattr(self, f'bbox_{i}'))
        anchors = torch.cat(anchors, dim=1)
        cls_preds = concat_preds(cls_preds)
        # Reshape to (batch, total_anchors, num_classes + 1) for cross-entropy.
        cls_preds = cls_preds.reshape(
            cls_preds.shape[0], -1, self.num_classes + 1)
        bbox_preds = concat_preds(bbox_preds)
        return anchors, cls_preds, bbox_preds
#训练模型
#读取数据集和初始化
# Per-stage anchor sizes and aspect ratios.
# BUG FIX: these globals (and num_anchors) are read by TinySSD.__init__ and
# TinySSD.forward but were never defined anywhere, causing a NameError at
# TinySSD(num_classes=1). Values follow the d2l TinySSD recipe.
sizes = [[0.2, 0.272], [0.37, 0.447], [0.54, 0.619],
         [0.71, 0.79], [0.88, 0.961]]
ratios = [[1, 2, 0.5]] * 5
num_anchors = len(sizes[0]) + len(ratios[0]) - 1

# Data, model and optimizer.
batch_size = 32
train_iter, _ = d2l.load_data_bananas(batch_size)
device, net = d2l.try_gpu(), TinySSD(num_classes=1)
trainer = torch.optim.SGD(net.parameters(), lr=0.2, weight_decay=5e-4)
# Element-wise losses; reduction is applied manually in calc_loss.
cls_loss = nn.CrossEntropyLoss(reduction='none')
bbox_loss = nn.L1Loss(reduction='none')
def calc_loss(cls_preds, cls_labels, bbox_preds, bbox_labels, bbox_masks):
    """Per-example loss: mean class cross-entropy + masked L1 on box offsets.

    ``bbox_masks`` zeroes out anchors that were not assigned a ground-truth
    box, so background anchors contribute no offset loss.
    """
    batch_size, num_classes = cls_preds.shape[0], cls_preds.shape[2]
    cls = cls_loss(cls_preds.reshape(-1, num_classes),
                   cls_labels.reshape(-1)).reshape(batch_size, -1).mean(dim=1)
    # BUG FIX: nn.L1Loss takes (input, target); the original passed only one
    # tensor, which raises a TypeError at the first training step.
    bbox = bbox_loss(bbox_preds * bbox_masks,
                     bbox_labels * bbox_masks).mean(dim=1)
    return cls + bbox
#使用平均绝对误差来评价边界框的预测结果
def cls_eval(cls_preds, cls_labels):
    """Number of anchors whose predicted class (argmax) matches the label.

    The argmax indices are cast to the LABEL dtype before comparison
    (original cast to cls_preds.dtype, i.e. float — the equality should be
    taken in the labels' integer dtype, matching the d2l reference).
    """
    return float((cls_preds.argmax(dim=-1).type(cls_labels.dtype)
                  == cls_labels).sum())
def bbox_eval(bbox_preds, bbox_labels, bbox_masks):
    """Sum of absolute offset errors over anchors kept by ``bbox_masks``."""
    masked_error = (bbox_labels - bbox_preds) * bbox_masks
    return float(masked_error.abs().sum())
#训练模型
# Training loop: tracks class error and bbox MAE per epoch.
# (Fixes the backslash-escaped quotes in the original, which were syntax errors.)
num_epochs, timer = 20, d2l.Timer()
animator = d2l.Animator(xlabel='epoch', xlim=[1, num_epochs],
                        legend=['class error', 'bbox mae'])
net = net.to(device)
for epoch in range(num_epochs):
    # Accumulates: correct class preds, #class labels, abs bbox error, #bbox labels.
    metric = d2l.Accumulator(4)
    net.train()
    for features, target in train_iter:
        timer.start()
        trainer.zero_grad()
        X, Y = features.to(device), target.to(device)
        anchors, cls_preds, bbox_preds = net(X)
        # Assign a ground-truth class and offset target to every anchor.
        bbox_labels, bbox_masks, cls_labels = d2l.multibox_target(anchors, Y)
        l = calc_loss(cls_preds, cls_labels, bbox_preds, bbox_labels,
                      bbox_masks)
        l.mean().backward()
        trainer.step()
        metric.add(cls_eval(cls_preds, cls_labels), cls_labels.numel(),
                   bbox_eval(bbox_preds, bbox_labels, bbox_masks),
                   bbox_labels.numel())
    cls_err, bbox_mae = 1 - metric[0] / metric[1], metric[2] / metric[3]
    animator.add(epoch + 1, (cls_err, bbox_mae))
#预测目标
# Load the test image as a float batch of shape (1, 3, H, W); keep an
# HWC integer copy for display.
# (Fixes the backslash-escaped quotes in the original path literal, which
# were a syntax error.)
X = torchvision.io.read_image('../img/banana.jpg').unsqueeze(0).float()
img = X.squeeze(0).permute(1, 2, 0).long()
def predict(X):
net.eval()
anchors,cls_preds,bbox_preds = net(X.to(device))