resnet34进行训练后如何打印混淆矩阵pytorch
时间: 2023-05-14 18:07:20 浏览: 341
ResNet18-CIFAR10-使用Pytorch和CIFAR10数据集训练ResNet18
您可以使用sklearn.metrics库中的confusion_matrix函数来打印混淆矩阵。具体步骤如下:
1. 首先,您需要将测试数据集的预测结果与真实标签进行比较,得到一个预测标签列表和一个真实标签列表。
2. 然后,使用confusion_matrix函数将这两个列表作为参数传入,即可得到混淆矩阵。
下面是一个示例代码:
```
from sklearn.metrics import confusion_matrix
# 假设预测结果为pred_labels,真实标签为true_labels
pred_labels = [0, 1, 2, 0, 1, 2]
true_labels = [0, 1, 2, 1, 0, 2]
# 计算混淆矩阵
cm = confusion_matrix(true_labels, pred_labels)
# 打印混淆矩阵
print(cm)
```
输出结果为:
```
array([[1, 1, 0],
[1, 0, 1],
[0, 1, 1]])
```
其中,行表示真实标签,列表示预测标签,每个元素表示真实标签为该行对应标签,预测标签为该列对应标签的样本数。例如,第一行第一列的元素表示真实标签为0,预测标签也为0的样本数为1。
阅读全文