代码实现混淆矩阵可视化
时间: 2023-08-06 21:05:43 浏览: 100
机器学习 评估指标 之混淆矩阵以及可视化
混淆矩阵是一种用于评估分类模型性能的工具,它可以显示实际类别与预测类别之间的关系。在Python中,我们可以使用Scikit-learn库的confusion_matrix函数来计算混淆矩阵。然后,我们可以使用Matplotlib库的imshow函数将混淆矩阵可视化。
下面是一个简单的例子,展示如何使用Scikit-learn和Matplotlib库来实现混淆矩阵的可视化:
```python
# 导入必要的库
from sklearn.metrics import confusion_matrix
import matplotlib.pyplot as plt
import numpy as np
# 生成随机数据
y_true = np.random.randint(0, 2, size=100)
y_pred = np.random.randint(0, 2, size=100)
# 计算混淆矩阵
cm = confusion_matrix(y_true, y_pred)
# 可视化混淆矩阵
plt.imshow(cm, cmap=plt.cm.Blues)
# 添加标题、轴标签和刻度
plt.title('Confusion matrix')
plt.xlabel('Predicted')
plt.ylabel('Actual')
tick_marks = np.arange(2)
plt.xticks(tick_marks, ['Class 0', 'Class 1'])
plt.yticks(tick_marks, ['Class 0', 'Class 1'])
# 添加文本标签
thresh = cm.max() / 2.
for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
plt.text(j, i, format(cm[i, j], 'd'),
horizontalalignment="center",
color="white" if cm[i, j] > thresh else "black")
# 显示图像
plt.show()
```
在这个例子中,我们首先生成了100个随机的实际类别和预测类别,并使用Scikit-learn库的confusion_matrix函数计算混淆矩阵。然后,我们使用Matplotlib库的imshow函数将混淆矩阵可视化。我们还添加了标题、轴标签和刻度,并使用文本标签在图像中显示混淆矩阵中的数字。最后,我们使用show函数显示图像。
这里的关键在于使用imshow函数将混淆矩阵可视化,以及使用text函数在图像中添加文本标签。你可以根据需要调整图像大小、颜色映射、文本标签样式等等。
阅读全文