pytorch输出模型性能的代码
时间: 2023-06-05 12:06:12 浏览: 255
以下是在PyTorch中输出模型性能的代码示例:
``` python
import torch
# 计算模型在测试集上的准确率
def test(model, test_data_loader):
model.eval()
num_correct = 0
num_samples = 0
with torch.no_grad():
for data, targets in test_data_loader:
data = data.cuda()
targets = targets.cuda()
scores = model(data)
_, predictions = scores.max(1)
num_correct += (predictions == targets).sum()
num_samples += predictions.size(0)
accuracy = float(num_correct) / float(num_samples) * 100
print('Test Accuracy: {:.2f}%'.format(accuracy))
# 训练模型
# ...
# 在测试集上测试模型的性能
test(model, test_data_loader)
```
需要注意的是,这里的`model`是一个PyTorch模型,`test_data_loader`是一个测试数据集的`DataLoader`对象。如果你想输出其他性能指标,可以参考PyTorch的文档进行操作。
阅读全文