多分类混淆矩阵热力图
时间: 2024-12-29 13:22:51 浏览: 10
### 创建和解读多分类问题中的混淆矩阵热力图
对于多分类问题,创建并解读混淆矩阵热力图有助于直观了解模型性能。通过可视化工具可以更清晰地展示不同类别之间的预测准确性。
#### 使用Python绘制混淆矩阵热力图
为了更好地理解和解释混淆矩阵,在实际应用中通常会借助`seaborn`库来绘制热力图。下面是一个完整的例子:
```python
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import numpy as np
def plot_confusion_matrix(y_true, y_pred, classes,
normalize=False,
title='Confusion matrix',
cmap=plt.cm.Blues):
"""
This function prints and plots the confusion matrix.
Normalization can be applied by setting `normalize=True`.
"""
cm = confusion_matrix(y_true, y_pred)
if normalize:
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
fig, ax = plt.subplots(figsize=(8, 6))
im = ax.imshow(cm, interpolation='nearest', cmap=cmap)
ax.figure.colorbar(im, ax=ax)
# We want to show all ticks...
ax.set(xticks=np.arange(cm.shape[1]),
yticks=np.arange(cm.shape[0]),
# ...and label them with the respective list entries
xticklabels=classes, yticklabels=classes,
title=title,
ylabel='True label',
xlabel='Predicted label')
# Rotate the tick labels and set their alignment.
plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
rotation_mode="anchor")
# Loop over data dimensions and create text annotations.
fmt = '.2f' if normalize else 'd'
thresh = cm.max() / 2.
for i in range(cm.shape[0]):
for j in range(cm.shape[1]):
ax.text(j, i, format(cm[i, j], fmt),
ha="center", va="center",
color="white" if cm[i, j] > thresh else "black")
fig.tight_layout()
return ax
```
此函数接受真实标签`y_true`、预测标签`y_pred`以及类别名称列表`classes`作为输入参数,并可选地设置是否对数据进行标准化处理以便于百分比形式查看[^2]。
#### 解读混淆矩阵热力图
- **行**代表真实的类别;
- **列**代表由模型预测出来的类别;
- 如果某个单元格的颜色较深,则意味着该位置上的数值较大,表明对应的真实类被误分为另一类的情况较多;反之亦然。
当观察到某些特定区域颜色异常浓重时,这可能暗示着存在一些难以区分的相似特征使得这些样本容易被判错。此时可以通过调整模型结构或增加更多训练样本来改善这种情况[^1]。
阅读全文