Pytorch 训练框架,日志管理,可视化

torchfurnace

torchfurnace 是一个集快速训练模型,日志管理,模型checkpoints管理,tensorboard可视化, I/O 加速,模型大小统计于一身的工具包。

使用这个工具包可以快速构建一个深度学习训练,不需要自己写各种训练逻辑,对于已经定义好的模型也不需要修改,

可以说是拿来即用

使用: pip install torchfurnace

github: https://github.com/tianyu-su/torchfurnace

下面的例子是快速搭建训练,使用 VGG16 训练 CIFAR10

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
from torchvision.datasets import CIFAR10
from torch.optim.lr_scheduler import MultiStepLR
from torchfurnace import Engine, Parser
from torchfurnace.utils.function import accuracy

# define training process of your model
class VGGNetEngine(Engine):
    """Training engine for VGG: one forward/backward step plus the LR schedule."""

    @staticmethod
    def _on_forward(training, model, inp, target, optimizer=None) -> dict:
        """Run one forward pass; when ``training`` is True also backprop and step.

        Returns a dict with scalar ``'loss'``, ``'acc1'`` and ``'acc5'`` entries
        for the framework's logging.
        """
        output = model(inp)
        loss = F.cross_entropy(output, target)

        if training:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        acc1, acc5 = accuracy(output, target, topk=(1, 5))
        # .item() extracts plain Python floats, detached from the autograd graph
        return {'loss': loss.item(), 'acc1': acc1.item(), 'acc5': acc5.item()}

    @staticmethod
    def _get_lr_scheduler(optim) -> list:
        """Decay the learning rate by 10x at epochs 150, 250 and 350."""
        return [MultiStepLR(optim, milestones=[150, 250, 350], gamma=0.1)]

def main():
    """Train VGG16 on CIFAR10 using the torchfurnace Engine and print top-1 accuracy."""
    # Experiment name, e.g. "CIFAR10_<suffix>", built from the parsed CLI args.
    parser = Parser('TVGG16')
    args = parser.parse_args()
    experiment_name = '_'.join([args.dataset, args.exp_suffix])

    # CIFAR10 datasets with the standard per-channel normalization statistics.
    ts = transforms.Compose([transforms.ToTensor(), transforms.Normalize(
        (0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))])
    trainset = CIFAR10(root='data', train=True, download=True, transform=ts)
    testset = CIFAR10(root='data', train=False, download=True, transform=ts)

    # VGG16 adapted for 32x32 inputs: no-op average pool and a 10-class head.
    net = models.vgg16(pretrained=False, num_classes=10)
    net.avgpool = nn.AvgPool2d(kernel_size=1, stride=1)
    net.classifier = nn.Linear(512, 10)
    optimizer = torch.optim.Adam(net.parameters())

    # Run the framework's training/validation loop; returns best top-1 accuracy.
    eng = VGGNetEngine(parser, experiment_name)
    acc1 = eng.learning(net, optimizer, trainset, testset)
    print('Acc1:', acc1)

if __name__ == '__main__':
    import sys
    # Inject default CLI flags so the example runs without manual arguments.
    run_params = '--dataset CIFAR10 -lr 0.1 -bs 128 -j 2 --epochs 400 --adjust_lr'
    sys.argv.extend(run_params.split())
    main()