用python画混淆矩阵
时间: 2023-10-24 22:07:24 浏览: 82
要绘制混淆矩阵,我们可以使用Python中的Matplotlib库和Scikit-learn库。
首先,需要导入所需的库和数据。以下是一个例子代码:
```python
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
y_true = [0, 1, 2, 0, 1, 2, 0, 1, 2] # 真实标签
y_pred = [0, 1, 1, 0, 2, 2, 1, 1, 2] # 预测标签
labels = ['Class 0', 'Class 1', 'Class 2'] # 类别标签
```
接下来,我们使用`confusion_matrix`函数计算混淆矩阵,并使用Matplotlib库绘制热力图:
```python
cm = confusion_matrix(y_true, y_pred)
plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
plt.title('Confusion Matrix')
plt.colorbar()
tick_marks = np.arange(len(labels))
plt.xticks(tick_marks, labels, rotation=45)
plt.yticks(tick_marks, labels)
plt.tight_layout()
plt.xlabel('True label')
plt.ylabel('Predicted label')
plt.show()
```
这将生成以下混淆矩阵可视化结果:
![confusion_matrix](https://i.imgur.com/4mOlrqW.png)
希望这个例子可以帮助你绘制混淆矩阵。
阅读全文