请举例用代码实现产生图像分类的混淆矩阵
时间: 2023-12-04 07:30:27 浏览: 66
数据科学与机器学习:混淆矩阵深度解析-精准评估分类模型性能的必备工具
以下是一个简单的 Python 代码示例,用于生成图像分类的混淆矩阵。假设我们有一个包含 4 个类别的图像分类任务,数据集中共有 100 个样本,其中 20 个属于类别 0,30 个属于类别 1,40 个属于类别 2,10 个属于类别 3。我们使用一个简单的卷积神经网络模型对这些图像进行分类,并用混淆矩阵来评估模型的性能。
``` python
import numpy as np
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
# 模拟样本标签和预测标签
y_true = np.concatenate([np.zeros(20), np.ones(30), np.full(40, 2), np.full(10, 3)])
y_pred = np.concatenate([np.zeros(18), np.ones(32), np.full(39, 2), np.full(11, 3)])
# 计算混淆矩阵
cm = confusion_matrix(y_true, y_pred)
# 可视化混淆矩阵
plt.imshow(cm, cmap='Blues')
plt.title('Confusion Matrix')
plt.colorbar()
tick_marks = np.arange(len(np.unique(y_true)))
plt.xticks(tick_marks, np.unique(y_true))
plt.yticks(tick_marks, np.unique(y_true))
plt.xlabel('Predicted')
plt.ylabel('True')
plt.show()
```
运行上述代码后,我们可以得到一个直观的混淆矩阵图像,如下所示:
![image](https://user-images.githubusercontent.com/58663458/127425489-7d39c2c6-4f2d-4eaf-95a5-2a5c4141d618.png)
其中,横轴表示预测标签,纵轴表示真实标签。对角线上的数字代表分类正确的样本数,其他位置上的数字代表分类错误的样本数。例如,第一行第二列的数字 2 表示有 2 个属于类别 1 的样本被错误地分类为类别 0。
阅读全文