混淆矩阵热力图python
时间: 2023-07-01 15:07:37 浏览: 356
详解使用python绘制混淆矩阵(confusion_matrix)
3星 · 编辑精心推荐
混淆矩阵热力图可以用 Python 中的 Matplotlib 和 Seaborn 库来实现。具体步骤如下:
1. 导入所需库:
```python
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
```
2. 计算混淆矩阵:
```python
y_true = [0, 1, 0, 1, 0, 1, 0, 0, 1, 1]
y_pred = [0, 0, 0, 1, 0, 1, 1, 0, 1, 0]
cm = confusion_matrix(y_true, y_pred)
```
3. 绘制热力图:
```python
sns.heatmap(cm, annot=True, cmap='Blues')
plt.xlabel('Predicted labels')
plt.ylabel('True labels')
plt.show()
```
完整代码如下:
```python
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
# Example data
y_true = [0, 1, 0, 1, 0, 1, 0, 0, 1, 1]
y_pred = [0, 0, 0, 1, 0, 1, 1, 0, 1, 0]
# Calculate confusion matrix
cm = confusion_matrix(y_true, y_pred)
# Plot heatmap
sns.heatmap(cm, annot=True, cmap='Blues')
plt.xlabel('Predicted labels')
plt.ylabel('True labels')
plt.show()
```
运行上述代码,即可得到混淆矩阵热力图。其中,热力图中的数字表示预测正确或预测错误的样本数,越大的数字颜色越深。同时,横轴表示预测结果,纵轴表示真实标签。可以根据热力图判断分类模型的性能。
阅读全文