针对图像五分类的混淆矩阵代码
时间: 2023-08-06 19:16:28 浏览: 101
好的,下面是一个针对图像五分类的混淆矩阵代码示例:
```python
import numpy as np
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
# 定义分类标签
classes = ['class1', 'class2', 'class3', 'class4', 'class5']
# 假设预测结果为一维数组preds,真实标签为一维数组labels
preds = np.array([1, 2, 4, 1, 3, 3, 1, 2, 2, 4, 5, 5, 4, 3, 2, 1, 3, 2, 5, 4])
labels = np.array([1, 2, 4, 1, 3, 5, 1, 2, 2, 4, 5, 5, 4, 3, 2, 1, 3, 2, 5, 4])
# 计算混淆矩阵
cm = confusion_matrix(labels, preds)
# 可视化混淆矩阵
fig, ax = plt.subplots()
im = ax.imshow(cm, cmap=plt.cm.Blues)
ax.set_xticks(np.arange(len(classes)))
ax.set_yticks(np.arange(len(classes)))
ax.set_xticklabels(classes)
ax.set_yticklabels(classes)
plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
rotation_mode="anchor")
for i in range(len(classes)):
for j in range(len(classes)):
text = ax.text(j, i, cm[i, j],
ha="center", va="center", color="w")
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
fig.tight_layout()
plt.show()
```
此代码使用sklearn.metrics中的confusion_matrix方法计算混淆矩阵,并使用matplotlib.pyplot中的imshow方法可视化混淆矩阵。你可以根据自己的需求修改代码中的分类标签、预测结果和真实标签,以适应不同的图像分类任务。
阅读全文