A Multi-Node, Multi-GPU Model Training Example
Model Training in a Multi-Node, Multi-GPU Environment
Multi-node, multi-GPU training in a distributed environment can significantly improve model-training efficiency on large-scale datasets. PyTorch provides the torch.distributed module to support this type of parallel computation[^1].
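As a minimal sketch (not taken from the article) of the module's basic lifecycle, each participating process joins a process group and can then query its own rank and the total number of processes:

import torch.distributed as dist

# Assumes the launcher has already set MASTER_ADDR, MASTER_PORT, RANK and
# WORLD_SIZE in the environment; "env://" tells PyTorch to read them there.
dist.init_process_group(backend="nccl", init_method="env://")
print(f"rank {dist.get_rank()} of {dist.get_world_size()} processes")
dist.destroy_process_group()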
Using PyTorch's DistributedDataParallel (DDP)
Compared with the older DataParallel approach, DistributedDataParallel is better suited to the efficient communication required on large clusters. With NVIDIA's NCCL library as the backend, efficient collective-communication primitives such as broadcast and all-reduce can be performed between GPUs[^3].
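As a small illustration (not part of the original script, and assuming a process group has already been initialized as in the setup function further below), these collectives operate directly on CUDA tensors; the function name demo_collectives is purely for demonstration:

import torch
import torch.distributed as dist

def demo_collectives(rank):
    # Each rank starts with its own tensor on its local GPU.
    t = torch.ones(3, device=rank) * (rank + 1)
    # all_reduce sums the tensor across all ranks, in place on every rank.
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    # broadcast then overwrites every rank's copy with rank 0's version.
    dist.broadcast(t, src=0)
    return t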
Below is a simple example of a DDP training script:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.utils.data import DataLoader, DistributedSampler
from torchvision import datasets, transforms
from torch.nn.parallel import DistributedDataParallel as DDP


def setup(rank, world_size):
    # Address and port of the rank-0 process; on a real multi-node cluster
    # MASTER_ADDR must be the master node's IP rather than localhost.
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'
    # Initialize the process group with the NCCL backend.
    dist.init_process_group("nccl", rank=rank, world_size=world_size)


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        layers = []
        current_dim = 784  # MNIST images are 28*28
        for _ in range(2):
            next_dim = max(current_dim // 2, 10)
            layers.append(nn.Linear(current_dim, next_dim))
            layers.append(nn.ReLU())
            current_dim = next_dim
        self.network = nn.Sequential(*layers, nn.Linear(next_dim, 10))

    def forward(self, x):
        batch_size = x.size(0)
        return self.network(x.view(batch_size, -1))


def train_model(rank, world_size):
    setup(rank, world_size)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])
    dataset_train = datasets.MNIST('.', download=True, transform=transform, train=True)
    # DistributedSampler splits the dataset so each rank sees a distinct shard.
    sampler = DistributedSampler(dataset_train, num_replicas=world_size, rank=rank)
    dataloader = DataLoader(dataset_train, batch_size=64, shuffle=False, sampler=sampler)

    model = ToyModel().to(rank)
    ddp_model = DDP(model, device_ids=[rank])
    loss_fn = nn.CrossEntropyLoss()
    optimizer = optim.Adam(ddp_model.parameters())

    for epoch in range(10):
        # Re-seed the sampler each epoch so the shuffling differs across epochs.
        sampler.set_epoch(epoch)
        running_loss = 0.0
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(rank), labels.to(rank)
            outputs = ddp_model(inputs)
            loss = loss_fn(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        if rank == 0:
            print(f"epoch {epoch}: loss {running_loss / len(dataloader):.4f}")

    dist.destroy_process_group()


if __name__ == "__main__":
    world_size = 4  # assume four GPUs are available
    mp.spawn(train_model,
             args=(world_size,),
             nprocs=world_size,
             join=True)
This code shows how to set up a multi-process, multi-GPU training loop, using the MNIST handwritten-digit recognition task as a simplified illustration. As written it launches all processes on a single machine (MASTER_ADDR is 'localhost'); a real multi-node deployment additionally needs the master node's address, a per-node launcher, and further configuration and performance tuning.
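One common way to extend this to several machines, sketched below under the assumption that the torchrun launcher is used (the helper setup_from_env is not part of the original script), is to let each process read its connection details from the environment variables that torchrun sets:

import os
import torch
import torch.distributed as dist

def setup_from_env():
    # torchrun exports RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR and
    # MASTER_PORT for every process, so nothing needs to be hardcoded here.
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)
    return local_rank

Each node would then run the script with something like torchrun --nnodes=2 --nproc_per_node=4 --node_rank=<n> --master_addr=<master-ip> --master_port=29500 train.py (the node count and script name are illustrative), and the returned local rank is used as the device index in place of the rank argument supplied by mp.spawn.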