sklearn如何输出混淆矩阵
时间: 2024-09-23 20:10:10 浏览: 33
使用Python和sklearn库生成混淆矩阵的实例
在scikit-learn库中,输出混淆矩阵通常用于评估分类模型的性能,它显示了真实类别和预测类别之间的对应情况。以下是生成混淆矩阵的基本步骤:
1. 首先,你需要有一个训练好的分类模型,并对测试数据进行预测。
```python
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import confusion_matrix
# 加载数据集
iris = datasets.load_iris()
X_train, X_test, y_train, y_test = train_test_split(iris.data, iris.target, test_size=0.2, random_state=42)
# 创建并训练模型
model = LogisticRegression()
model.fit(X_train, y_train)
# 对测试数据进行预测
y_pred = model.predict(X_test)
```
2. 然后,你可以使用`confusion_matrix`函数来计算混淆矩阵。
```python
# 计算混淆矩阵
cm = confusion_matrix(y_test, y_pred)
print(cm)
```
这将输出一个二维数组,其中行代表实际类别,列表示预测类别。矩阵的主要对角线表示正确的预测,非对角线则表示错误的预测。
阅读全文