pytorch 网络可视化

今天使用hiddenlayer测试了下retinanet网络的可视化。

首先,安装hiddlayer,直接pip pip install git+https://github.com/waleedka/hiddenlayer.git

然后在终端加载模型并显示:

import model, torch
import hiddenlayer as hl

retinanet = model.resnet18(num_classes=100, pretrained=True).cuda()
x = torch.rand((1, 3, 224, 224)).cuda().float()
ann = torch.tensor([[[20.0, 30.0, 53.2, 33.3, 32.0]]]).cuda().float()
hl.build_graph(retinanet, [x, ann])
hl.save('/home/willer/model.pdf')

模型太复杂了,放在这里了。

昨天晚上对比着模型结构的pdf和代码又看了下,发现还是很有用的,起码对网络的数据流动的认识更加清晰了。