pytorch怎么使用主函数
时间: 2024-02-17 17:29:01 浏览: 19
在 PyTorch 中,通常使用主函数来组织代码并运行模型的训练和测试。以下是一个简单的示例:
```python
import torch
def main():
# 加载数据集
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)
# 初始化模型、损失函数和优化器
model = MyModel()
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
# 训练模型
for epoch in range(10):
for input, target in train_loader:
output = model(input)
loss = criterion(output, target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 测试模型
correct = 0
total = 0
with torch.no_grad():
for input, target in test_loader:
output = model(input)
_, predicted = torch.max(output.data, 1)
total += target.size(0)
correct += (predicted == target).sum().item()
accuracy = 100 * correct / total
print(f'Test accuracy: {accuracy:.2f}%')
if __name__ == '__main__':
main()
```
在这个示例中,我们首先加载了训练和测试数据集,并初始化了模型、损失函数和优化器。接下来,我们使用一个 for 循环来迭代训练数据集,并在每个迭代中计算模型的输出、损失和梯度,并使用优化器来更新模型参数。在训练完成后,我们使用测试数据集来评估模型的准确性。最后,我们输出测试准确率作为结果。
在主函数中,我们可以执行任意的 PyTorch 操作,并使用 Python 的控制流语句来控制模型的训练和测试过程。通过组织代码并将其放在主函数中,我们可以更好地管理代码,并使其更易于维护和扩展。