PyTorch DDP: launching multi-node, multi-GPU training with mp.spawn (code example)

The example below initializes an NCCL process group, wraps a model in DistributedDataParallel, and uses torch.multiprocessing.spawn to start one worker process per GPU on each node.
import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp

def run(local_rank, node_rank, nproc_per_node, world_size):
    # Global rank = node index * processes per node + local process index
    rank = node_rank * nproc_per_node + local_rank
    # Initialize the process group; MASTER_ADDR/MASTER_PORT must point to node 0
    dist.init_process_group("nccl", rank=rank, world_size=world_size)
    # Bind this process to its local GPU (device index must be local, not global)
    torch.cuda.set_device(local_rank)
    # Define the model and optimizer on this process's GPU
    model = torch.nn.Linear(10, 1).to(local_rank)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    # Wrap the model so gradients are synchronized across all processes
    model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank])
    # Dummy data (in practice, shard a real dataset with DistributedSampler)
    data = torch.randn(1000, 10).to(local_rank)
    target = torch.randn(1000, 1).to(local_rank)
    # Training loop
    for epoch in range(10):
        optimizer.zero_grad()
        output = model(data)
        loss = torch.nn.functional.mse_loss(output, target)
        loss.backward()
        optimizer.step()
        print(f"Rank: {rank}, Epoch: {epoch}, Loss: {loss.item()}")
    # Clean up the process group
    dist.destroy_process_group()

if __name__ == "__main__":
    # Rendezvous address: IP and port of node 0, identical on every node
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    nproc_per_node = 4                                  # GPUs per node
    nnodes = int(os.environ.get("NNODES", "1"))         # total number of nodes
    node_rank = int(os.environ.get("NODE_RANK", "0"))   # index of this node
    world_size = nnodes * nproc_per_node
    # Spawn one process per GPU; spawn passes the local rank as the first argument
    mp.spawn(run, args=(node_rank, nproc_per_node, world_size),
             nprocs=nproc_per_node, join=True)
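Assuming the script above is saved as train_ddp.py (a hypothetical filename) and each node has 4 GPUs, a two-node run could be launched by executing one command per node, with MASTER_ADDR set to node 0's IP (192.168.1.10 here is just a placeholder):

    # On node 0:
    MASTER_ADDR=192.168.1.10 MASTER_PORT=29500 NNODES=2 NODE_RANK=0 python train_ddp.py
    # On node 1:
    MASTER_ADDR=192.168.1.10 MASTER_PORT=29500 NNODES=2 NODE_RANK=1 python train_ddp.py

All 8 processes join the same process group, so world_size is 8 and global ranks run from 0 to 7.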
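Note that the example trains every process on the same random tensors; with a real dataset, each rank should see a distinct shard. A minimal sketch using torch.utils.data.DistributedSampler, assuming rank, local_rank, and world_size are available as inside run() above:

    import torch
    from torch.utils.data import DataLoader, TensorDataset, DistributedSampler

    # CPU tensors; batches are moved to the GPU inside the loop
    dataset = TensorDataset(torch.randn(1000, 10), torch.randn(1000, 1))
    # Each rank iterates over a disjoint 1/world_size shard of the dataset
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    loader = DataLoader(dataset, batch_size=32, sampler=sampler)
    for epoch in range(10):
        sampler.set_epoch(epoch)  # vary the shuffling across epochs
        for batch_data, batch_target in loader:
            batch_data = batch_data.to(local_rank)
            batch_target = batch_target.to(local_rank)
            # ... forward/backward/step as in the training loop above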