通过调用metrics.confusion_matrix来混淆矩阵
时间: 2024-05-08 07:21:35 浏览: 116
confusion_matrix_混淆矩阵_源码
5星 · 资源好评率100%
混淆矩阵是用于评估分类模型性能的一种工具,它可以展示模型在测试集上真实标签和预测标签之间的关系。在 Python 中,可以通过调用 scikit-learn 库的 metrics 模块中的 confusion_matrix 函数来生成混淆矩阵。以下是一个示例代码:
``` python
from sklearn.metrics import confusion_matrix
y_true = [0, 1, 0, 1, 0, 1]
y_pred = [0, 0, 1, 1, 0, 1]
cm = confusion_matrix(y_true, y_pred)
print(cm)
```
在这个例子中,`y_true` 和 `y_pred` 分别代表了真实标签和预测标签。调用 `confusion_matrix(y_true, y_pred)` 函数会返回一个混淆矩阵,打印结果如下:
```
[[3 0]
[1 2]]
```
这个混淆矩阵的行表示真实标签,列表示预测标签,矩阵中的每个元素表示真实标签和预测标签相同的样本数。例如,第一行第一列的 3 表示真实标签为 0,但模型预测为 0 的样本数有 3 个;第二行第一列的 1 表示真实标签为 1,但模型预测为 0 的样本数有 1 个。通过分析混淆矩阵,我们可以得到模型在不同类别上的预测准确率和误判情况等信息。
阅读全文