混淆矩阵python代码实现
时间: 2023-08-31 14:09:41 浏览: 72
好的,以下是混淆矩阵的 Python 代码实现:
```python
# 导入相关库
from sklearn.metrics import confusion_matrix
import numpy as np
import matplotlib.pyplot as plt
# 生成样本数据,假设有2类(0和1),每类都有50个样本
y_true = np.concatenate([np.zeros(50), np.ones(50)])
y_pred = np.concatenate([np.zeros(25), np.ones(25), np.zeros(25), np.ones(25)])
# 计算混淆矩阵
cm = confusion_matrix(y_true, y_pred)
# 可视化混淆矩阵
plt.imshow(cm, cmap=plt.cm.Blues)
# 添加图例
classes = ['Class 0', 'Class 1']
tick_marks = np.arange(len(classes))
plt.xticks(tick_marks, classes)
plt.yticks(tick_marks, classes)
# 添加文本细节
thresh = cm.max() / 2.
for i in range(cm.shape[0]):
for j in range(cm.shape[1]):
plt.text(j, i, format(cm[i, j], 'd'), ha="center", va="center",
color="white" if cm[i, j] > thresh else "black")
# 添加其他细节
plt.title("Confusion Matrix")
plt.xlabel("Predicted Label")
plt.ylabel("True Label")
plt.colorbar()
plt.show()
```
以上代码生成一个大小为 $2\times2$ 的混淆矩阵,用于评估二分类模型的预测性能。在实际使用中,需要根据任务复杂度和实验需求选择并使用不同大小的混淆矩阵。
阅读全文