pytorch怎样绘制有五种类别数据的混淆矩阵
时间: 2023-12-03 12:43:35 浏览: 68
下面是使用PyTorch绘制混淆矩阵的一个例子,它可以处理五种类别的数据。
```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
# 生成随机的真实标签和预测标签
y_true = np.random.randint(low=0, high=5, size=100) # 生成100个真实标签,范围是[0,4]
y_pred = np.random.randint(low=0, high=5, size=100) # 生成100个预测标签,范围是[0,4]
# 计算混淆矩阵
cm = confusion_matrix(y_true, y_pred)
# 绘制混淆矩阵
fig, ax = plt.subplots()
im = ax.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
ax.figure.colorbar(im, ax=ax)
ax.set(xticks=np.arange(cm.shape[1]),
yticks=np.arange(cm.shape[0]),
xticklabels=['class 0', 'class 1', 'class 2', 'class 3', 'class 4'], # 五个类别的标签
yticklabels=['class 0', 'class 1', 'class 2', 'class 3', 'class 4'],
title='Confusion matrix',
ylabel='True label',
xlabel='Predicted label')
# 在每个单元格中添加数值
thresh = cm.max() / 2.
for i in range(cm.shape[0]):
for j in range(cm.shape[1]):
ax.text(j, i, format(cm[i, j], 'd'),
ha="center", va="center",
color="white" if cm[i, j] > thresh else "black")
fig.tight_layout()
plt.show()
```
在上面的代码中,我们首先生成了100个随机的真实标签和预测标签,范围是[0,4]。然后使用`confusion_matrix`函数计算混淆矩阵。最后使用Matplotlib绘制混淆矩阵,其中每个单元格中的数值表示真实标签和预测标签匹配的数量。
注意,我们需要使用`xticklabels`和`yticklabels`参数指定每个类别的标签。如果你的数据类别不同,需要相应地修改这些标签。
阅读全文