pytorch 分布式训练

1.初始化进程组

dist.init_process_group(backend="nccl")

backend是后台利用nccl进行通信

2.使样本之间能够进行通信

train_sampler = torch.utils.data.distributed.DistributedSampler(trainset)

3.创建ddp模型

model = DDP(model,device_ids=[local_rank],output_device=local_rank,find_unused_parameters=True)

获得local_rank(在运行launch时会传入一个local_rank参数)

local_rank = torch.distributed.local_rank()

torch.cuda.set_device(local_rank)

运行脚本

CUDA_VISIBLE_DEVICES=$device python -m torch.distributed.launch --nproc_per_node=1 train.py

保存模型(注意只需保存主进程上的模型,保存的是ddp模型的module)

if dist.get_rank() == 0: torch.save(model.module, "model.pth")