混淆矩阵heatmap图
时间: 2023-12-04 18:41:22 浏览: 169
混淆矩阵是用于评估分类模型性能的一种矩阵。在混淆矩阵中,行表示实际类别,列表示预测类别。对角线上的元素表示正确分类的样本数,非对角线上的元素表示错误分类的样本数。混淆矩阵heatmap图可以直观地展示混淆矩阵的结果。
以下是一个绘制混淆矩阵heatmap图的例子:
```python
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
# 构造混淆矩阵
confusion_matrix = np.array([[50, 10, 5], [5, 20, 15], [10, 5, 30]])
# 绘制heatmap图
fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(confusion_matrix, annot=True, cmap='Blues', fmt='g', ax=ax)
# 设置图形属性
ax.set_xlabel('Predicted labels')
ax.set_ylabel('True labels')
ax.set_title('Confusion Matrix')
# 显示图形
plt.show()
```
在上述代码中,我们首先构造了一个3x3的混淆矩阵,然后使用seaborn库中的heatmap函数绘制了混淆矩阵的heatmap图。其中,annot=True表示在heatmap图中显示数值,cmap='Blues'表示使用蓝色调色板,fmt='g'表示使用普通数字格式显示数值。
相关问题
# 混淆矩阵 cm = metrics.confusion_matrix(y_test, sklearn_predict, labels = [0, 1]) print('混淆矩阵:\n', cm) # 混淆矩阵热力图 import seaborn as sns sns.heatmap(cm, annot=True, fmt = '.2e', cmap = 'GnBu')报错的原因
该报错可能是因为您没有导入所需的库,比如 `metrics` 和 `seaborn` 库。请确保您已经通过以下方式将它们导入到您的代码中:
```python
from sklearn import metrics
import seaborn as sns
```
如果已经导入了这些库,那么请检查您是否正确地使用了这些库的函数和参数。另外,如果您能提供更具体的错误信息,我将能够更好地帮助您解决问题。
sns.heatmap(confusion_mat, annot=True, cmap="Blues") plt.show()改成导出混淆矩阵图片文件
要将混淆矩阵保存为图片文件,您可以使用`matplotlib`库中的`savefig`函数。可以按照以下步骤进行操作:
1. 导入`matplotlib`库:
```python
import matplotlib.pyplot as plt
```
2. 绘制热力图并隐藏显示:
```python
plt.imshow(confusion_mat, cmap="Blues")
plt.colorbar()
plt.title("Confusion Matrix")
plt.xlabel("Predicted Labels")
plt.ylabel("True Labels")
plt.xticks(ticks=range(len(labels)), labels=labels, rotation=45)
plt.yticks(ticks=range(len(labels)), labels=labels)
plt.tight_layout()
plt.savefig("confusion_matrix.png")
```
3. 保存图片文件:
```python
plt.savefig("confusion_matrix.png")
```
以上代码将生成一个名为`confusion_matrix.png`的图片文件,保存在当前工作目录下。您可以将文件名更改为您喜欢的名称和路径。
阅读全文