准确率计算pytorch
时间: 2023-11-15 21:58:58 浏览: 43
Pytorch中可以使用混淆矩阵来计算准确率、误判率和召回率。混淆矩阵是一个表格,用于比较模型预测结果和真实结果之间的差异。在混淆矩阵中,行表示真实结果,列表示预测结果。对角线上的数字表示正确分类的数量,非对角线上的数字表示错误分类的数量。通过混淆矩阵,可以计算出准确率、误判率和召回率等指标。
下面是一个计算准确率的Pytorch代码示例:
```python
from sklearn.metrics import confusion_matrix
# 定义模型和数据集
model = ...
test_loader = ...
# 测试模型
model.eval()
y_true = []
y_pred = []
with torch.no_grad():
for data, target in test_loader:
output = model(data)
pred = output.argmax(dim=1, keepdim=True)
y_true.extend(target.tolist())
y_pred.extend(pred.tolist())
# 计算混淆矩阵
conf_mat = confusion_matrix(y_true, y_pred)
# 计算准确率
accuracy = conf_mat.diagonal().sum() / conf_mat.sum()
print('Accuracy:', accuracy)
```
在上面的代码中,首先定义了模型和数据集,然后使用模型对测试集进行预测,并将预测结果和真实结果保存在`y_pred`和`y_true`列表中。接着,使用`confusion_matrix`函数计算混淆矩阵,最后通过混淆矩阵计算准确率。