如何可视化混淆矩阵,给出具体代码
时间: 2024-02-04 18:10:23 浏览: 66
可视化混淆矩阵可以帮助我们更直观地了解分类模型的性能表现。下面是一份可视化混淆矩阵的示例代码,使用了matplotlib库的imshow函数来绘制混淆矩阵的热力图:
```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
# 示例标签和预测值
y_true = [0, 1, 1, 0, 1, 1, 0, 0, 1, 0]
y_pred = [0, 0, 1, 0, 1, 1, 0, 1, 1, 0]
# 计算混淆矩阵
cm = confusion_matrix(y_true, y_pred)
# 绘制热力图
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
plt.colorbar()
plt.xticks(np.arange(2), ['Negative', 'Positive'])
plt.yticks(np.arange(2), ['Negative', 'Positive'])
plt.xlabel('Predicted label')
plt.ylabel('True label')
plt.title('Confusion Matrix')
plt.show()
```
在这个示例中,我们使用了sklearn.metrics中的`confusion_matrix`函数计算混淆矩阵,然后使用matplotlib库的`imshow`函数在图像中绘制出混淆矩阵的热力图。可以看到,混淆矩阵的行表示真实标签,列表示预测标签,因此,真实标签为负样本,预测标签为负样本的数量为3,真实标签为正样本,预测标签为正样本的数量为4。
阅读全文