Python Matplotlib实现混淆矩阵可视化与规范化

需积分: 5 7 下载量 191 浏览量 更新于2024-08-03 收藏 2KB TXT 举报
在Python编程中,混淆矩阵是一种重要的评估模型性能的工具,尤其是在分类问题中,如机器学习中的二分类或多分类任务。混淆矩阵展示了真实类别(实际标签)与预测类别(预测标签)之间的对比,有助于理解模型在不同类别上的表现,比如精确度、召回率和F1分数等。本文主要介绍如何使用matplotlib库在Python中绘制混淆矩阵。 首先,导入必要的库,包括matplotlib和numpy,因为混淆矩阵通常基于numpy数组计算: ```python import matplotlib.pyplot as plt import numpy as np ``` 函数`plot_Matrix`是核心部分,它接受四个参数:混淆矩阵`cm`、类别标签`classes`、图片标题`title`和颜色映射`cmap`。这个函数的主要步骤包括: 1. **标准化混淆矩阵**:由于混淆矩阵通常是按行进行归一化,即将每一行的总和设为1,这样可以更好地展示每个类别的相对性能。这里通过将矩阵元素除以各自行的和来实现。 2. **字符串格式化**:将归一化的混淆矩阵转换为字符串列表,便于后续打印和可视化。 3. **处理小数精度**:对于占比小于1%的单元格,设置为0,这有助于在颜色映射中避免过于精细的区分。 4. **创建图像**:使用`plt.subplots()`创建一个新的图形,并使用`imshow()`函数绘制混淆矩阵,指定颜色映射`cmap`和插值方法。 5. **设置坐标轴标签**:设置x轴和y轴的标签,分别为预测类别和实际类别,以及图片的标题。 6. **添加网格线**:使用`grid()`函数添加网格线,同时调整网格线的样式、颜色和宽度。对次要刻度线进行调整,使它们不显示标签,仅用于视觉分隔。 7. **旋转标签**:为了清晰地显示长标签,将x轴的标签旋转45度。 示例代码如下: ```python def plot_Matrix(cm, classes, title=None, cmap=plt.cm.Blues): # ...(上述步骤) for i in range(cm.shape[0]): for j in range(cm.shape[1]): if int(cm[i, j] * 100 + 0.5) == 0: cm[i, j] = 0 fig, ax = plt.subplots() im = ax.imshow(cm, interpolation='nearest', cmap=cmap) # ...(继续设置其他属性) plt.show() # 使用示例 # 假设你已经有了混淆矩阵cm和类别列表classes cm = ... # 混淆矩阵数值 classes = ... # 类别标签 plot_Matrix(cm, classes, 'Confusion Matrix Example', cmap=plt.cm.Blues) ``` 通过这个函数,你可以快速生成一个具有美观颜色映射和清晰标签的混淆矩阵图,这对于理解和分析模型性能非常有帮助。此外,如果你需要在不同的项目或场景中重复使用这个功能,可以将其封装成一个类或者模块,方便重用和定制。