Visualizing PyTorch with tensorboardX

Visualizing loss and accuracy

Reference: https://www.jianshu.com/p/46eb3004beca

  1. Environment setup:

    conda activate xxx

    pip install tensorboardX

    pip install tensorflow    # pulls in TensorBoard, which provides the tensorboard command

  2. Code:

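    from tensorboardX import SummaryWriter

    # Event files are written to ./runs/001
    writer = SummaryWriter('runs/001')
    # Once per epoch, inside the training loop (train_loss, batch_idx, correct,
    # total and epoch come from that loop; batch_idx is 0-based from enumerate)
    writer.add_scalar('Train/Loss', train_loss / (batch_idx + 1), epoch)
    writer.add_scalar('Train/Acc', 100.0 * correct / total, epoch)
    # After training, flush and close the event file
    writer.close()

  3. On the server: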

    conda activate xxx

    tensorboard --logdir=runs/001

  4. On the local machine:

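    In a terminal, run: ssh -p 22222 -L 6006:localhost:6006 yinwenbin@192.168.2.237

    (-p 22222 is the server's SSH port; -L forwards local port 6006 to port 6006 on the server, TensorBoard's default port.)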

    In a browser, open: localhost:6006
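The logging code in step 2 relies on `train_loss`, `batch_idx`, `correct` and `total` coming from the training loop. Below is a minimal stdlib-only sketch of that bookkeeping, assuming `batch_idx` comes from `enumerate` (so it is 0-based and the batch count is `batch_idx + 1`) and `train_loss` is a running sum of per-batch losses. `log_epoch`, `ScalarWriter` and `RecordingWriter` are hypothetical names; any object with an `add_scalar(tag, scalar_value, global_step)` method, such as tensorboardX's `SummaryWriter`, can be passed as the writer.

```python
from typing import Protocol


class ScalarWriter(Protocol):
    """Anything with add_scalar(tag, scalar_value, global_step), e.g. tensorboardX's SummaryWriter."""

    def add_scalar(self, tag: str, scalar_value: float, global_step: int) -> None:
        ...


class RecordingWriter:
    """Hypothetical stand-in writer that only records calls, for trying the helper without TensorBoard."""

    def __init__(self) -> None:
        self.calls = []

    def add_scalar(self, tag: str, scalar_value: float, global_step: int) -> None:
        self.calls.append((tag, scalar_value, global_step))


def log_epoch(writer: ScalarWriter, epoch: int, batch_losses, batch_correct, batch_sizes):
    """Accumulate per-batch stats for one epoch, log mean loss and accuracy (%), and return both."""
    train_loss, correct, total, num_batches = 0.0, 0, 0, 0
    for batch_idx, (loss, n_correct, n) in enumerate(zip(batch_losses, batch_correct, batch_sizes)):
        train_loss += loss
        correct += n_correct
        total += n
        # enumerate() starts at 0, so after this batch we have seen batch_idx + 1 batches
        num_batches = batch_idx + 1
    if num_batches == 0 or total == 0:
        raise ValueError('no batches to average over')
    avg_loss = train_loss / num_batches
    acc = 100.0 * correct / total
    # Same tags as in step 2; the epoch is used as the x-axis step
    writer.add_scalar('Train/Loss', avg_loss, epoch)
    writer.add_scalar('Train/Acc', acc, epoch)
    return avg_loss, acc


if __name__ == '__main__':
    w = RecordingWriter()
    print(log_epoch(w, 0, batch_losses=[1.0, 0.5], batch_correct=[3, 4], batch_sizes=[4, 4]))
    # (0.75, 87.5)
    print(w.calls)
```

With tensorboardX installed, pass `SummaryWriter('runs/001')` instead of `RecordingWriter()` and call `writer.close()` once training finishes.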

Visualizing the model

Reference: https://blog.csdn.net/sunqiande88/article/details/80155925?utm_source=copy

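    import torch
    import torchvision.models as models
    from tensorboardX import SummaryWriter

    model = models.resnet18()
    # add_graph traces the model with an example input: a batch of 13 random 3x224x224 images
    dummy_input = torch.rand(13, 3, 224, 224)

    # With no logdir, SummaryWriter creates an auto-named subdirectory of ./runs ending in 'resnet18'
    with SummaryWriter(comment='resnet18') as w:
        w.add_graph(model, (dummy_input, ))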

Then launch TensorBoard on the parent directory runs:

    conda activate xxx

    tensorboard --logdir runs

In a browser, open: localhost:6006
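Because `--logdir runs` points TensorBoard at the parent directory, each subdirectory under `runs` that holds event files shows up as a separate run, for example `runs/001` from the loss/accuracy section and the auto-named directory created by `SummaryWriter(comment='resnet18')`. A small stdlib sketch that lists which runs TensorBoard should pick up; `list_runs` is a hypothetical helper, and matching on the `tfevents` substring assumes the usual `events.out.tfevents.*` file naming.

```python
import os
import tempfile


def list_runs(logdir):
    """Return sorted run names (paths relative to logdir) whose directory contains event files."""
    runs = set()
    for root, _dirs, files in os.walk(logdir):
        # Event files are named like events.out.tfevents.<timestamp>.<hostname>
        if any('tfevents' in name for name in files):
            runs.add(os.path.relpath(root, logdir))
    return sorted(runs)


if __name__ == '__main__':
    # Throwaway tree with two runs (directory names are made up) and one empty directory
    with tempfile.TemporaryDirectory() as logdir:
        for run in ('001', 'example_resnet18'):
            os.makedirs(os.path.join(logdir, run))
            open(os.path.join(logdir, run, 'events.out.tfevents.0.host'), 'w').close()
        os.makedirs(os.path.join(logdir, 'empty'))
        print(list_runs(logdir))  # ['001', 'example_resnet18']
```

Running it against your real `runs` directory is a quick way to check that both the scalar logs and the model graph landed where TensorBoard is looking.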

If you get the error 'torch._C.Value' object has no attribute 'debugName', downgrade tensorboardX from 1.9 to 1.8:

    pip install tensorboardX==1.8

Reference: https://blog.csdn.net/East_Plain/article/details/103073311
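The error comes from tensorboardX 1.9 calling `debugName`, which, as far as I understand, older PyTorch releases (before 1.3) exposed under the name `uniqueName` instead. A hedged stdlib sketch that reads the installed versions with `importlib.metadata` and flags the combination that triggers the error; the 1.3 and 1.9 thresholds are assumptions based on that reading, and `parse_major_minor`, `needs_downgrade` and `installed_version` are hypothetical names.

```python
import re
from importlib import metadata


def parse_major_minor(version):
    """Turn strings like '1.2.0+cu92' or '1.9' into a (major, minor) tuple."""
    match = re.match(r'(\d+)\.(\d+)', version)
    if match is None:
        raise ValueError(f'unrecognised version string: {version!r}')
    return int(match.group(1)), int(match.group(2))


def needs_downgrade(torch_version, tbx_version):
    """True if tensorboardX >= 1.9 is paired with PyTorch < 1.3 (assumed thresholds)."""
    return parse_major_minor(tbx_version) >= (1, 9) and parse_major_minor(torch_version) < (1, 3)


def installed_version(package):
    """Installed version of a distribution, or None if it is not installed."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


if __name__ == '__main__':
    torch_v = installed_version('torch')
    tbx_v = installed_version('tensorboardX')
    if torch_v and tbx_v and needs_downgrade(torch_v, tbx_v):
        print('Try: pip install tensorboardX==1.8')
    else:
        print(f'torch={torch_v}, tensorboardX={tbx_v}: no known debugName conflict')
```

Upgrading PyTorch instead of downgrading tensorboardX should also clear the conflict under the same assumption, but that was not what the reference above tested.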