给我一段做混淆矩阵的代码跟能画图的
时间: 2024-02-24 10:56:39 浏览: 27
当您使用机器学习算法进行分类任务时,混淆矩阵是一种常用的评估指标。以下是一个 Python 代码示例,用于生成混淆矩阵并将其可视化:
```python
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
import numpy as np
# 预测值和真实标签
y_pred = np.array([1, 0, 0, 1, 1, 0, 1])
y_true = np.array([1, 1, 0, 0, 1, 0, 1])
# 生成混淆矩阵
cm = confusion_matrix(y_true, y_pred)
# 创建热图
sns.set(font_scale=1.4) # 设置字体大小
sns.heatmap(cm, annot=True, annot_kws={"size": 16}, cmap='Blues', fmt='g')
# 添加标签和标题
plt.xlabel('Predicted labels')
plt.ylabel('True labels')
plt.title('Confusion Matrix')
# 显示图形
plt.show()
```
在这个例子中,我们有 7 个样本,其中 4 个样本被正确分类,3 个样本被错误分类。混淆矩阵显示了真实标签和预测标签之间的关系,其中每行代表真实标签,每列代表预测标签。颜色越深表示预测结果越准确。
相关问题
带数值的混淆矩阵画图代码
以下是Python代码,使用Matplotlib库绘制带数值的混淆矩阵图:
```python
import numpy as np
import matplotlib.pyplot as plt
# 定义混淆矩阵
confusion_matrix = np.array([[25, 5, 0], [5, 20, 5], [0, 5, 25]])
# 定义标签
labels = ['Class 0', 'Class 1', 'Class 2']
# 绘制混淆矩阵图
fig, ax = plt.subplots()
im = ax.imshow(confusion_matrix, cmap='Blues')
# 显示数值
for i in range(len(labels)):
for j in range(len(labels)):
text = ax.text(j, i, confusion_matrix[i, j],
ha="center", va="center", color="w")
# 设置标题、轴标签、刻度标签
ax.set_title("Confusion Matrix")
ax.set_xticks(np.arange(len(labels)))
ax.set_yticks(np.arange(len(labels)))
ax.set_xticklabels(labels)
ax.set_yticklabels(labels)
ax.set_xlabel("Predicted label")
ax.set_ylabel("True label")
# 旋转刻度标签以避免重叠
plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
rotation_mode="anchor")
# 在混淆矩阵图旁添加颜色条
cbar = ax.figure.colorbar(im, ax=ax)
# 显示图形
plt.show()
```
运行以上代码,将得到带数值的混淆矩阵图。可以根据实际需要修改混淆矩阵、标签、标题和颜色映射等参数。
读取两个列表完成带数值的混淆矩阵画图代码
以下是Python代码:
```python
import matplotlib.pyplot as plt
import numpy as np
# 两个列表,分别表示预测值和真实值
predicted = [1, 0, 1, 1, 0, 1, 1, 0, 1, 0]
actual = [1, 0, 0, 1, 1, 1, 0, 0, 1, 0]
# 计算混淆矩阵
confusion_matrix = np.zeros((2, 2))
for p, a in zip(predicted, actual):
confusion_matrix[p][a] += 1
# 画图
fig, ax = plt.subplots(figsize=(5, 5))
ax.matshow(confusion_matrix, cmap=plt.cm.Blues)
for i in range(confusion_matrix.shape[0]):
for j in range(confusion_matrix.shape[1]):
ax.text(x=j, y=i, s=int(confusion_matrix[i, j]), va='center', ha='center')
ax.set_xlabel('True label')
ax.set_ylabel('Predicted label')
plt.show()
```
解释一下代码的过程:
- 首先定义了两个列表 `predicted` 和 `actual`,分别表示模型的预测结果和真实结果。
- 接着使用 NumPy 库创建了一个 2x2 的全零数组 `confusion_matrix`,用于记录混淆矩阵。
- 遍历 `predicted` 和 `actual` 列表,根据预测结果和真实结果更新混淆矩阵。
- 最后使用 Matplotlib 库画出混淆矩阵的图像,其中 `ax.matshow()` 函数用于显示矩阵,`ax.text()` 函数用于在每个单元格中添加数值标签。
相关推荐
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![rar](https://img-home.csdnimg.cn/images/20210720083606.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)