python绘制混淆矩阵热力图
时间: 2024-04-01 12:30:20 浏览: 162
在Python中,可以使用matplotlib库和seaborn库来绘制混淆矩阵的热力图。下面是一个简单的示例代码:
```python
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
# 创建混淆矩阵数据
confusion_matrix = np.array([[50, 10, 5],
[5, 40, 15],
[10, 5, 45]])
# 绘制热力图
plt.figure(figsize=(8, 6))
sns.heatmap(confusion_matrix, annot=True, cmap="YlGnBu", fmt="d")
plt.xlabel("Predicted label")
plt.ylabel("True label")
plt.title("Confusion Matrix Heatmap")
plt.show()
```
在这个示例中,我们首先创建了一个3x3的混淆矩阵数据。然后使用`sns.heatmap()`函数绘制热力图,其中`annot=True`表示在每个格子中显示数值,`cmap="YlGnBu"`表示使用蓝绿色调色板,`fmt="d"`表示显示整数类型的数值。
你可以根据自己的混淆矩阵数据进行相应的修改和调整,以满足你的需求。
相关问题
python 绘制混淆矩阵
要使用Python绘制混淆矩阵,你可以使用混淆矩阵的数据来创建热力图。下面是一个使用Matplotlib库绘制混淆矩阵的示例代码:
```python
import numpy as np
import matplotlib.pyplot as plt
# 定义混淆矩阵数据
confusion_matrix = np.array([[100, 10, 0],
[5, 90, 5],
[0, 8, 92]])
# 计算各类别的总数
class_totals = confusion_matrix.sum(axis=1)
# 计算各类别的准确率
class_accuracy = confusion_matrix / class_totals[:, np.newaxis]
# 设置标签
labels = ['Class A', 'Class B', 'Class C']
# 创建热力图
fig, ax = plt.subplots()
im = ax.imshow(class_accuracy, cmap='Blues')
# 设置颜色条
cbar = ax.figure.colorbar(im, ax=ax)
# 设置坐标轴标签
ax.set(xticks=np.arange(class_accuracy.shape[1]),
yticks=np.arange(class_accuracy.shape[0]),
xticklabels=labels, yticklabels=labels,
title='Confusion Matrix',
ylabel='True label',
xlabel='Predicted label')
# 在热力图中显示数值
for i in range(class_accuracy.shape[0]):
for j in range(class_accuracy.shape[1]):
ax.text(j, i, format(class_accuracy[i, j], '.2f'),
ha="center", va="center", color="white")
# 确保标签不被裁剪
plt.tight_layout()
# 显示图形
plt.show()
```
你可以根据你实际的混淆矩阵数据进行修改。这段代码会生成一个热力图,其中每个单元格表示分类器在预测时的准确率。标签显示了真实标签和预测标签的类别。
混淆矩阵热力图python
混淆矩阵热力图可以用 Python 中的 Matplotlib 和 Seaborn 库来实现。具体步骤如下:
1. 导入所需库:
```python
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
```
2. 计算混淆矩阵:
```python
y_true = [0, 1, 0, 1, 0, 1, 0, 0, 1, 1]
y_pred = [0, 0, 0, 1, 0, 1, 1, 0, 1, 0]
cm = confusion_matrix(y_true, y_pred)
```
3. 绘制热力图:
```python
sns.heatmap(cm, annot=True, cmap='Blues')
plt.xlabel('Predicted labels')
plt.ylabel('True labels')
plt.show()
```
完整代码如下:
```python
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
# Example data
y_true = [0, 1, 0, 1, 0, 1, 0, 0, 1, 1]
y_pred = [0, 0, 0, 1, 0, 1, 1, 0, 1, 0]
# Calculate confusion matrix
cm = confusion_matrix(y_true, y_pred)
# Plot heatmap
sns.heatmap(cm, annot=True, cmap='Blues')
plt.xlabel('Predicted labels')
plt.ylabel('True labels')
plt.show()
```
运行上述代码,即可得到混淆矩阵热力图。其中,热力图中的数字表示预测正确或预测错误的样本数,越大的数字颜色越深。同时,横轴表示预测结果,纵轴表示真实标签。可以根据热力图判断分类模型的性能。
阅读全文