pytorch 如何画出混淆矩阵的图
时间: 2023-03-29 17:02:13 浏览: 124
06 绘制混淆矩阵 python
5星 · 资源好评率100%
可以使用sklearn.metrics库中的confusion_matrix函数来计算混淆矩阵,然后使用matplotlib库中的imshow函数来绘制混淆矩阵的图像。具体实现可以参考以下代码:
```python
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
# 生成随机的真实标签和预测标签
y_true = np.random.randint(, 10, size=100)
y_pred = np.random.randint(, 10, size=100)
# 计算混淆矩阵
cm = confusion_matrix(y_true, y_pred)
# 绘制混淆矩阵图像
plt.imshow(cm, cmap=plt.cm.Blues)
plt.colorbar()
plt.xlabel('Predicted labels')
plt.ylabel('True labels')
plt.xticks(np.arange(10))
plt.yticks(np.arange(10))
plt.title('Confusion matrix')
plt.show()
```
这段代码会生成一个随机的混淆矩阵,并将其绘制成图像。你可以根据自己的需求修改代码中的数据和参数。
阅读全文