混淆矩阵图python
时间: 2024-01-19 18:00:37 浏览: 115
object_detection_confusion_matrix:Python类,用于计算对象检测任务的混淆矩阵
混淆矩阵图是一种展示分类模型性能的可视化工具,用于比较有监督学习模型在多类别任务中的分类结果。在Python中,可以使用各种数据科学和机器学习库来生成和绘制混淆矩阵图,如matplotlib和seaborn。
首先,需要导入所需的库,并将真实标签和预测标签作为输入数据。可以使用sklearn库中的metrics模块来计算混淆矩阵。
```Python
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn import metrics
# 真实标签和预测标签
y_true = [0, 1, 0, 1, 2, 0, 2, 2]
y_pred = [0, 1, 0, 2, 1, 0, 2, 1]
# 计算混淆矩阵
confusion_matrix = metrics.confusion_matrix(y_true, y_pred)
```
接下来,可以使用seaborn库的heatmap函数将混淆矩阵可视化。
```Python
# 绘制混淆矩阵图
plt.figure(figsize=(8, 6))
sns.heatmap(confusion_matrix, annot=True, cmap="Blues")
# 添加轴标签
plt.xlabel("预测标签")
plt.ylabel("真实标签")
plt.title("混淆矩阵图")
# 显示图形
plt.show()
```
以上代码将生成一个具有颜色编码单元格和数字注释的矩形热图。行表示真实标签,列表示预测标签。热图的纵向表示了预测标签的准确性和误判情况,横向表示了真实标签的分布情况。注释数字表示每个类别的样本数。根据颜色的深浅可以观察到每个类别的分类情况,以及模型的整体性能。
通过混淆矩阵图,我们可以直观地了解分类模型在每个类别上的表现,并进一步优化模型以提高分类性能。
阅读全文