confusion_matrix返回得到什么
时间: 2024-01-27 21:05:49 浏览: 25
`confusion_matrix`是一个用于计算分类模型的混淆矩阵的函数。混淆矩阵是一个表格,用于比较分类模型的预测结果和真实结果。它可以帮助我们了解模型的性能,包括准确率、召回率、F1得分等指标。`confusion_matrix`函数返回一个数组,其中行表示真实标签,列表示预测标签。数组中的每个元素表示真实标签和预测标签的组合数量。例如,数组中的第一行第二列表示真实标签为第一类,但被预测为第二类的样本数量。下面是一个使用`confusion_matrix`函数的例子:
```python
from sklearn.metrics import confusion_matrix
y_true = [0, 1, 0, 1, 1, 1]
y_pred = [1, 1, 0, 1, 0, 1]
cm = confusion_matrix(y_true, y_pred)
print(cm)
```
输出结果为:
```
array([[1, 2],
[1, 2]])
```
这个混淆矩阵表示真实标签为0的样本有1个被预测为0,2个被预测为1;真实标签为1的样本有1个被预测为0,2个被预测为1。
相关问题
为什么我使用confusion_matrix获得的结果是高维数组
`confusion_matrix` 函数返回的结果是一个多维数组,其中每个元素表示实际类别和预测类别的组合数量。如果你的分类问题有多个类别,那么就会生成一个高维数组。例如,对于一个有三个类别的分类问题,`confusion_matrix` 函数返回一个 3x3 的矩阵。如果你的分类问题有更多的类别,那么它将返回一个更高维度的数组。
如果你想查看每个类别的分类报告,可以使用 `classification_report` 函数。它将返回每个类别的精确度、召回率、F1分数和支持数等指标。如果你的分类问题有多个类别,那么它将返回一个包含每个类别指标的字典。
以下是一个使用 `confusion_matrix` 和 `classification_report` 函数的例子:
```python
from sklearn.metrics import confusion_matrix, classification_report
y_true = [0, 1, 2, 0, 1, 2]
y_pred = [0, 1, 1, 0, 2, 1]
cm = confusion_matrix(y_true, y_pred)
print(cm)
# 输出结果
# [[2 0 0]
# [0 1 1]
# [0 1 1]]
report = classification_report(y_true, y_pred)
print(report)
# 输出结果
# precision recall f1-score support
#
# 0 1.00 1.00 1.00 2
# 1 0.50 0.50 0.50 2
# 2 0.50 0.50 0.50 2
#
# accuracy 0.67 6
# macro avg 0.67 0.67 0.67 6
# weighted avg 0.67 0.67 0.67 6
```
在上面的例子中,我们使用 `confusion_matrix` 函数计算了真实标签和预测标签的混淆矩阵。结果是一个 3x3 的矩阵,其中每个元素表示实际类别和预测类别的组合数量。我们还使用 `classification_report` 函数计算了每个类别的分类报告。结果显示了每个类别的精确度、召回率、F1分数和支持数等指标。
通过调用metrics.confusion_matrix来混淆矩阵
混淆矩阵是用于评估分类模型性能的一种工具,它可以展示模型在测试集上真实标签和预测标签之间的关系。在 Python 中,可以通过调用 scikit-learn 库的 metrics 模块中的 confusion_matrix 函数来生成混淆矩阵。以下是一个示例代码:
``` python
from sklearn.metrics import confusion_matrix
y_true = [0, 1, 0, 1, 0, 1]
y_pred = [0, 0, 1, 1, 0, 1]
cm = confusion_matrix(y_true, y_pred)
print(cm)
```
在这个例子中,`y_true` 和 `y_pred` 分别代表了真实标签和预测标签。调用 `confusion_matrix(y_true, y_pred)` 函数会返回一个混淆矩阵,打印结果如下:
```
[[3 0]
[1 2]]
```
这个混淆矩阵的行表示真实标签,列表示预测标签,矩阵中的每个元素表示真实标签和预测标签相同的样本数。例如,第一行第一列的 3 表示真实标签为 0,但模型预测为 0 的样本数有 3 个;第二行第一列的 1 表示真实标签为 1,但模型预测为 0 的样本数有 1 个。通过分析混淆矩阵,我们可以得到模型在不同类别上的预测准确率和误判情况等信息。