torch.nn.DataParallel参数
时间: 2024-05-06 12:19:43 浏览: 285
torch.nn.DataParallel是一个用于分布式训练的PyTorch函数,它的参数如下:
- module (nn.Module):需要进行分布式训练的模型。
- device_ids (list of int):用于指定使用哪些GPU进行训练,例如[0, 1, 2]表示使用GPU0、GPU1和GPU2进行训练。
- output_device (int):指定模型输出的设备,默认为device_ids[0]。
示例:
```python
import torch
import torch.nn as nn
# 定义模型
class MyModel(nn.Module):
def __init__(self):
super(MyModel, self).__init__()
self.linear1 = nn.Linear(10, 20)
self.linear2 = nn.Linear(20, 5)
def forward(self, x):
x = self.linear1(x)
x = self.linear2(x)
return x
model = MyModel()
# 分布式训练
device_ids = [0, 1, 2]
model = nn.DataParallel(model, device_ids=device_ids)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
# 训练模型
inputs = torch.randn(100, 10)
labels = torch.randn(100, 5)
outputs = model(inputs)
loss = nn.MSELoss()(outputs, labels)
loss.backward()
optimizer.step()
```
阅读全文