python混淆矩阵画图颜色
时间: 2023-08-28 13:22:33 浏览: 277
在Python中,你可以使用`matplotlib`库来画混淆矩阵,并使用不同的颜色来表示不同的类别。下面是一个简单的例子:
```python
import numpy as np
import matplotlib.pyplot as plt
# 创建一个混淆矩阵
confusion_matrix = np.array([[10, 2, 3],
[4, 15, 6],
[7, 8, 20]])
# 设置类别标签
labels = ['Class 1', 'Class 2', 'Class 3']
# 设置颜色映射
colors = ['blue', 'green', 'red']
# 绘制混淆矩阵图
plt.figure(figsize=(8, 6))
plt.imshow(confusion_matrix, cmap='Blues') # 使用蓝色作为背景色
# 添加颜色条
cbar = plt.colorbar()
cbar.set_label('Count')
# 添加类别标签
相关问题
sklearn混淆矩阵画图
你可以使用 `sklearn.metrics` 模块中的 `confusion_matrix` 函数来计算混淆矩阵,然后使用 `matplotlib` 库中的 `imshow` 函数来绘制混淆矩阵的图像。下面是一个示例代码:
```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
# 假设你有真实标签和预测标签
y_true = np.array([0, 1, 0, 1, 1, 0, 0, 1])
y_pred = np.array([0, 1, 0, 0, 1, 1, 0, 1])
# 计算混淆矩阵
cm = confusion_matrix(y_true, y_pred)
# 绘制混淆矩阵图像
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
plt.title('Confusion Matrix')
plt.colorbar()
classes = ['Class 0', 'Class 1'] # 类别名称
tick_marks = np.arange(len(classes))
plt.xticks(tick_marks, classes)
plt.yticks(tick_marks, classes)
# 添加数据标签
thresh = cm.max() / 2.
for i in range(cm.shape[0]):
for j in range(cm.shape[1]):
plt.text(j, i, format(cm[i, j], 'd'),
ha="center", va="center",
color="white" if cm[i, j] > thresh else "black")
plt.tight_layout()
plt.ylabel('True label')
plt.xlabel('Predicted label')
plt.show()
```
这段代码会显示一个带有颜色编码的混淆矩阵图像,其中每个单元格的值表示预测为某一类别的样本数量。你可以根据自己的数据和类别名称进行相应的修改。希望对你有所帮助!
带数值的混淆矩阵画图代码
以下是Python代码,使用Matplotlib库绘制带数值的混淆矩阵图:
```python
import numpy as np
import matplotlib.pyplot as plt
# 定义混淆矩阵
confusion_matrix = np.array([[25, 5, 0], [5, 20, 5], [0, 5, 25]])
# 定义标签
labels = ['Class 0', 'Class 1', 'Class 2']
# 绘制混淆矩阵图
fig, ax = plt.subplots()
im = ax.imshow(confusion_matrix, cmap='Blues')
# 显示数值
for i in range(len(labels)):
for j in range(len(labels)):
text = ax.text(j, i, confusion_matrix[i, j],
ha="center", va="center", color="w")
# 设置标题、轴标签、刻度标签
ax.set_title("Confusion Matrix")
ax.set_xticks(np.arange(len(labels)))
ax.set_yticks(np.arange(len(labels)))
ax.set_xticklabels(labels)
ax.set_yticklabels(labels)
ax.set_xlabel("Predicted label")
ax.set_ylabel("True label")
# 旋转刻度标签以避免重叠
plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
rotation_mode="anchor")
# 在混淆矩阵图旁添加颜色条
cbar = ax.figure.colorbar(im, ax=ax)
# 显示图形
plt.show()
```
运行以上代码,将得到带数值的混淆矩阵图。可以根据实际需要修改混淆矩阵、标签、标题和颜色映射等参数。
阅读全文