请帮我写一个训练pytorch模型的标准接口函数,输入一个nn.modules类派生的模型,一个dataloader类的输入,和相关的其他参数
时间: 2024-03-18 17:42:33 浏览: 162
Pytorch 模型训练实用教程 代码免费下载
4星 · 用户满意度95%
当然可以,下面是一个标准的训练函数,可以完成你想要的任务:
```python
import torch
import torch.nn as nn
import torch.optim as optim
def train(model, dataloader, criterion, optimizer, device, num_epochs):
model = model.to(device)
best_acc = 0.0
for epoch in range(num_epochs):
print('Epoch {}/{}'.format(epoch+1, num_epochs))
print('-' * 10)
running_loss = 0.0
running_corrects = 0.0
for inputs, labels in dataloader:
inputs = inputs.to(device)
labels = labels.to(device)
optimizer.zero_grad()
with torch.set_grad_enabled(True):
outputs = model(inputs)
_, preds = torch.max(outputs, 1)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item() * inputs.size(0)
running_corrects += torch.sum(preds == labels.data)
epoch_loss = running_loss / len(dataloader.dataset)
epoch_acc = running_corrects.double() / len(dataloader.dataset)
print('Loss: {:.4f} Acc: {:.4f}'.format(epoch_loss, epoch_acc))
if epoch_acc > best_acc:
best_acc = epoch_acc
return model
```
该函数接收以下参数:
- `model`:要训练的 PyTorch 模型,必须是 `nn.modules` 类的子类。
- `dataloader`:用于训练模型的 PyTorch 数据加载器。
- `criterion`:损失函数,可以从 `torch.nn` 中选择。
- `optimizer`:优化器,可以从 `torch.optim` 中选择。
- `device`:设备名称,例如 `"cuda:0"` 或 `"cpu"`。
- `num_epochs`:训练的 epoch 数量。
该函数会将模型移动到指定的设备上,并迭代训练数据加载器中的所有批次,计算损失和精度,并更新参数。最后,该函数将返回一个经过训练的 PyTorch 模型。
阅读全文