PyTorch single-machine multi-GPU parallel computation example

A simple example.

Note:

os.environ['MASTER_ADDR'] = 'xxx.xx.xx.xxx'  # your machine's IP address
os.environ['MASTER_PORT'] = '29555'          # any free port

These two variables apparently must be set ahead of time, since the initialization method chosen here is init_method="env://" (the default, environment-variable based method), which reads the master address and port from them.
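
If you would rather not rely on environment variables, init_process_group also accepts an explicit TCP address. A minimal sketch of that variant, assuming a single machine (the loopback address and port below are placeholders, not values from this post):

import torch.distributed as dist

def init_process_group_tcp(rank, world_size):
    # pass the rendezvous address directly instead of reading MASTER_ADDR / MASTER_PORT
    dist.init_process_group(
        backend="gloo",
        init_method="tcp://127.0.0.1:29555",  # placeholder: master IP and a free port
        rank=rank,
        world_size=world_size,
    )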

# Single-machine multi-GPU parallel computation example

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP


# https://pytorch.org/docs/stable/notes/ddp.html


def example(local_rank, world_size):  # local_rank is supplied automatically by mp.spawn
    # create the default process group (gloo is used here; nccl is generally preferred for GPU training)
    dist.init_process_group(backend="gloo", init_method="env://", rank=local_rank, world_size=world_size)
    # create local model
    model = nn.Linear(10, 10).cuda(local_rank)
    # construct DDP model
    ddp_model = DDP(model, device_ids=[local_rank], output_device=local_rank)
    # define loss function and optimizer
    loss_fn = nn.MSELoss()
    optimizer = optim.SGD(ddp_model.parameters(), lr=0.001)

    # training loop
    for i in range(100):
        if local_rank == 0:  # print progress from rank 0 only; without this guard every process would print
            print(i)
        # forward pass
        outputs = ddp_model(torch.randn(20, 10).cuda(local_rank))
        labels = torch.randn(20, 10).cuda(local_rank)
        # backward pass
        optimizer.zero_grad()  # clear gradients left over from the previous iteration
        loss_fn(outputs, labels).backward()
        # update parameters
        optimizer.step()

    # clean up the process group when training is done
    dist.destroy_process_group()


def main():
    os.environ['MASTER_ADDR'] = 'xxx.xx.xx.xxx'  # your machine's IP address
    os.environ['MASTER_PORT'] = '29555'          # any free port
    world_size = torch.cuda.device_count()
    mp.spawn(example, args=(world_size,), nprocs=world_size, join=True)



if __name__ == "__main__":
    main()
    print('Done!')
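
The toy example above trains on random tensors. For a real dataset you would normally pair DDP with a DistributedSampler so that each process sees a different shard of the data. A minimal sketch of the data-loading side, assuming an in-memory TensorDataset as a stand-in (none of these names come from the original post):

import torch
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

def make_loader(local_rank, world_size, batch_size=20):
    # stand-in dataset of random (input, label) pairs, just for illustration
    dataset = TensorDataset(torch.randn(1000, 10), torch.randn(1000, 10))
    # DistributedSampler hands each rank a disjoint subset of the indices
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=local_rank)
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler)

When training for multiple epochs, call sampler.set_epoch(epoch) at the start of each epoch so that the shuffling differs between epochs.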