import itertools def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix(DNN)', cmap=plt.cm.Blues): """ This function prints and plots the confusion matrix. Normalization can be applied by setting `normalize=True`. """ plt.imshow(cm, interpolation='nearest', cmap=cmap) plt.title(title) plt.colorbar() tick_marks = np.arange(len(classes)) plt.xticks(tick_marks, classes, rotation=45) plt.yticks(tick_marks, classes) if normalize: cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] thresh = cm.max() / 2. for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): plt.text(j, i, cm[i, j], horizontalalignment="center", color="white" if cm[i, j] > thresh else "black") plt.tight_layout() plt.ylabel('True label') plt.xlabel('Predicted label')
时间: 2024-01-04 17:02:45 浏览: 133
python中itertools模块zip_longest函数详解
这段代码定义了一个绘制混淆矩阵(Confusion Matrix)图像的函数,主要包括以下参数:
- `cm`:混淆矩阵,是一个二维数组,其中每个元素表示模型预测结果的正确与否。
- `classes`:类别名称,是一个列表,其中每个元素表示一个类别的名称。
- `normalize`:是否对混淆矩阵进行归一化,默认为 False。
- `title`:图像标题,默认为 'Confusion matrix(DNN)'。
- `cmap`:颜色映射表,用于表示混淆矩阵的颜色,默认为 plt.cm.Blues。
函数的主体部分使用 Matplotlib 库绘制混淆矩阵图像。具体来说,首先使用 `imshow()` 函数在图像上绘制混淆矩阵,然后添加标题和颜色条。接着在 x 轴和 y 轴上标注类别名称,并根据 `normalize` 参数选择是否对混淆矩阵进行归一化处理。最后,使用 `text()` 函数在每个混淆矩阵元素的中心位置添加文本标签,以显示该元素的值。
最后,使用 `tight_layout()` 函数调整图像布局,然后在 x 轴和 y 轴上添加标签,分别表示预测标签和真实标签。
阅读全文