基于pytorch调用训练好的五分类模型并写出混淆矩阵代码
时间: 2023-09-11 15:08:36 浏览: 108
假设训练好的五分类模型的输出是一个大小为`(batch_size, 5)`的张量,每一行对应一个样本,每一列表示一个类别的概率,那么可以使用以下代码计算混淆矩阵:
```python
import torch
from sklearn.metrics import confusion_matrix
# 加载模型
model = torch.load('path/to/model')
# 设置模型为评估模式
model.eval()
# 加载测试数据
test_data = ...
# 预测测试数据的标签
with torch.no_grad():
outputs = model(test_data)
predicted_labels = torch.argmax(outputs, dim=1)
# 计算混淆矩阵
true_labels = ...
confusion = confusion_matrix(true_labels, predicted_labels, labels=[0, 1, 2, 3, 4])
print(confusion)
```
在上面的代码中,`confusion_matrix`函数从`sklearn.metrics`模块中导入,用于计算混淆矩阵。首先加载训练好的模型,并将其设置为评估模式。然后加载测试数据,并使用模型预测测试数据的标签。最后,将真实标签`true_labels`和预测标签`predicted_labels`传递给`confusion_matrix`函数,以计算混淆矩阵。其中,`labels`参数用于指定类别的顺序。最后打印混淆矩阵即可。
阅读全文
相关推荐

















