二分类的混淆矩阵python
时间: 2024-12-27 15:18:01 浏览: 6
### 计算和展示二分类模型的混淆矩阵
在Python中处理二分类问题时,`sklearn.metrics.confusion_matrix` 函数提供了便捷的方法来计算混淆矩阵。此函数接收真实标签和预测标签作为输入参数,并返回一个二维数组形式的混淆矩阵[^1]。
为了更好地理解和可视化这些结果,可以利用 `matplotlib.pyplot` 和 `seaborn` 库中的热图功能来图形化显示混淆矩阵的内容。这不仅有助于直观理解数据分布情况,还便于向非技术人员解释模型性能[^2]。
下面是一个完整的例子,展示了如何加载数据集、训练逻辑回归模型以及创建并绘制混淆矩阵:
```python
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
import matplotlib.pyplot as plt
# 创建模拟数据集
X, y = make_classification(n_samples=1000, n_features=20,
n_informative=2, n_redundant=10,
random_state=42)
# 划分训练集与测试集
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
# 初始化并拟合逻辑回归模型
model = LogisticRegression()
model.fit(X_train, y_train)
# 预测类别
y_pred = model.predict(X_test)
# 构建混淆矩阵
cm = confusion_matrix(y_test, y_pred)
# 绘制混淆矩阵
disp = ConfusionMatrixDisplay(confusion_matrix=cm)
disp.plot(cmap="Blues")
plt.show()
```
通过上述代码片段可以看出,除了简单的数值表示外,还可以进一步分析由混淆矩阵衍生出来的其他重要评价指标,比如准确率 (Accuracy)、精确度 (Precision)、召回率 (Recall/Sensitivity),F1分数等,这些都是衡量分类器表现的关键因素[^4]。
阅读全文