python绘制Confusion matrix 立体图的代码
时间: 2023-07-23 10:03:53 浏览: 115
以下是使用Matplotlib绘制Confusion Matrix立体图的示例代码:
```python
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
# 定义数据
conf_mat = np.array([[85, 10, 5], [10, 85, 5], [5, 10, 85]])
# 获取矩阵大小
rows, cols = conf_mat.shape
# 设置x、y、z轴的名称
names = ['Class1', 'Class2', 'Class3']
# 创建3D图形
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
# 遍历矩阵,绘制立方体
for x in range(rows):
for y in range(cols):
# 计算立方体的长、宽、高
dx = dy = dz = 1
dz = conf_mat[x][y] / np.sum(conf_mat) * 10
# 绘制立方体
ax.bar3d(x, y, 0, dx, dy, dz, color='b', alpha=0.5)
# 设置坐标轴名称和范围
ax.set_xlabel('Predicted labels')
ax.set_ylabel('True labels')
ax.set_zlabel('Count')
ax.set_xticks(range(rows))
ax.set_yticks(range(cols))
ax.set_xticklabels(names)
ax.set_yticklabels(names)
ax.set_xlim([0, rows])
ax.set_ylim([0, cols])
ax.set_zlim([0, np.max(conf_mat) / np.sum(conf_mat) * 10])
# 显示图形
plt.show()
```
运行代码后,将生成一个3D立体图,其中x、y轴表示预测标签和真实标签,z轴表示对应的计数。每个立方体的高度表示对应的计数,颜色表示透明度(越深越不透明)。
阅读全文