写一段画混淆矩阵的代码
时间: 2023-09-07 22:05:15 浏览: 114
### 回答1:
以下是一段用 Python 画混淆矩阵的代码示例:
```
import matplotlib.pyplot as plt
import numpy as np
# 假设您有如下混淆矩阵
cm = np.array([[100, 20], [30, 40]])
# 绘制混淆矩阵
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
# 添加标题
plt.title('Confusion Matrix')
# 添加数字标签
threshold = cm.max() / 2.
for i in range(cm.shape[0]):
for j in range(cm.shape[1]):
plt.text(j, i, cm[i, j],
horizontalalignment="center",
color="white" if cm[i, j] > threshold else "black")
# 添加 X 轴标签
plt.xlabel('Predicted label')
# 添加 Y 轴标签
plt.ylabel('True label')
# 添加图例
plt.colorbar()
# 去掉坐标轴的刻度
plt.xticks(np.arange(2), [0, 1])
plt.yticks(np.arange(2), [0, 1])
# 显示图像
plt.show()
```
### 回答2:
画混淆矩阵是用于衡量分类器性能的一种常用方法。以下是一个简单的Python代码示例,用于绘制混淆矩阵。
```python
import numpy as np
import matplotlib.pyplot as plt
def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues):
"""
画混淆矩阵函数
参数:
cm -- 混淆矩阵
classes -- 分类器所包含的类别
normalize -- 是否对矩阵进行归一化,默认为False
title -- 图表标题,默认为'Confusion matrix'
cmap -- 颜色图谱,默认为plt.cm.Blues
"""
if normalize:
cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] # 对混淆矩阵进行归一化处理
plt.figure(figsize=(8, 6))
plt.imshow(cm, interpolation='nearest', cmap=cmap)
plt.title(title)
plt.colorbar()
tick_marks = np.arange(len(classes))
plt.xticks(tick_marks, classes, rotation=45)
plt.yticks(tick_marks, classes)
fmt = '.2f' if normalize else 'd'
thresh = cm.max() / 2.
for i, j in np.ndindex(cm.shape):
plt.text(j, i, format(cm[i, j], fmt),
horizontalalignment="center",
color="white" if cm[i, j] > thresh else "black")
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.tight_layout()
# 示例用法
classes = ['A', 'B', 'C', 'D']
cm = np.array([[10, 2, 1, 0],
[3, 12, 1, 2],
[2, 1, 9, 1],
[0, 2, 0, 15]])
plot_confusion_matrix(cm, classes, normalize=False, title='Confusion Matrix')
plt.show()
```
上述代码中,`plot_confusion_matrix` 函数用于绘制混淆矩阵。`cm` 参数表示混淆矩阵,`classes` 参数表示分类器的类别,`normalize` 参数表示是否对混淆矩阵进行归一化处理。函数内部使用 `plt.imshow` 函数绘制矩阵,使用 `plt.xticks` 和 `plt.yticks` 函数设置坐标刻度,使用 `plt.text` 函数在图表中显示矩阵的数值。最后通过调用 `plt.show` 函数显示图表。
### 回答3:
画混淆矩阵是一种常见的数据可视化方法,用于衡量机器学习算法的性能。下面是一个简单的Python代码段,用于绘制一个2x2的混淆矩阵:
```
import matplotlib.pyplot as plt
import numpy as np
# 定义分类类别
labels = ['真实类别A', '真实类别B']
# 定义混淆矩阵的数据,每一行代表一个真实类别的样本数
matrix_data = np.array([[100, 20], # 预测为类别A的样本中,实际为类别A的个数为100,实际为类别B的个数为20
[30, 150]]) # 预测为类别B的样本中,实际为类别A的个数为30,实际为类别B的个数为150
# 绘制混淆矩阵
fig, ax = plt.subplots()
im = ax.imshow(matrix_data, interpolation='nearest', cmap=plt.cm.Blues)
# 添加颜色条
cbar = ax.figure.colorbar(im, ax=ax)
# 创建刻度标签
ax.set(xticks=np.arange(matrix_data.shape[1]),
yticks=np.arange(matrix_data.shape[0]),
xticklabels=labels, yticklabels=labels,
title='混淆矩阵',
ylabel='真实类别',
xlabel='预测类别')
# 在方块内显示具体数值
thresh = matrix_data.max() / 2.0
for i in range(matrix_data.shape[0]):
for j in range(matrix_data.shape[1]):
ax.text(j, i, format(matrix_data[i, j], 'd'),
ha="center", va="center",
color="white" if matrix_data[i, j] > thresh else "black")
# 调整布局
fig.tight_layout()
# 展示混淆矩阵
plt.show()
```
这段代码使用了`matplotlib`库来绘制混淆矩阵图像。首先,我们定义了分类的类别标签,然后创建一个包含混淆矩阵数据的2x2的NumPy数组。接下来,我们使用`imshow`函数绘制混淆矩阵的图像,并通过`colorbar`函数添加一个颜色条。最后,我们使用`set`函数设置刻度标签、标题和坐标轴标签。然后,我们使用两个嵌套的循环在每个方块内显示对应的具体数值。最后,通过调整布局和展示图像的方式展示混淆矩阵。
阅读全文