python混淆矩阵可视化的代码
时间: 2023-08-14 21:55:57 浏览: 94
以下是一个Python代码示例,用于可视化混淆矩阵:
```python
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
def plot_confusion_matrix(confusion_matrix, labels):
plt.figure(figsize=(len(labels)+2, len(labels)))
ax = plt.subplot()
sns.heatmap(confusion_matrix, annot=True, fmt="d", cmap="Blues", cbar=False)
ax.set_xlabel('Predicted labels')
ax.set_ylabel('True labels')
ax.set_xticklabels(labels)
ax.set_yticklabels(labels)
plt.title('Confusion Matrix')
plt.show()
# 示例数据
labels = ['Class 0', 'Class 1', 'Class 2']
confusion_matrix = np.array([[10, 2, 3],
[1, 12, 4],
[5, 6, 15]])
plot_confusion_matrix(confusion_matrix, labels)
```
这段代码使用了`numpy`库用于创建混淆矩阵数据,`matplotlib`库用于绘制图表,以及`seaborn`库用于设置热力图样式。你可以根据自己的需求修改标签和混淆矩阵数据。运行代码后,将会显示一个带有注释数字的热力图,表示混淆矩阵。
阅读全文