三分类混淆矩阵怎么计算
时间: 2025-01-03 22:29:32 浏览: 3
### 如何计算三分类问题的混淆矩阵
对于多类分类问题中的三个类别,构建混淆矩阵的过程与二元分类相似,只是规模更大。假设存在三个类别:A、B 和 C。
#### 定义真实值和预测值
为了创建一个三类别的混淆矩阵,需要记录每个样本的真实标签以及模型对该样本的预测标签。这可以通过如下方式实现:
- **True Positives (TP)**: 对于每一个类别而言,在该类别下的实际正例被正确识别的数量。
- **False Positives (FP)**: 被错误地标记为此类别的其他类别的实例数。
- **False Negatives (FN)**: 属于此类别但未被正确检测到的实例数量。
- **True Negatives (TN)**: 不属于当前考虑的特定类别且也被正确判断不属于此类目的所有情况总数[^1]。
然而,在一个多类别设置下,通常不会单独报告 TN 值,因为当涉及到多个类别时,定义哪些观测是非目标变得复杂。因此,主要关注的是 TP、FP 及 FN 的统计量。
#### 构建混淆矩阵表格
给定上述定义,可以建立一个 3×3 的方阵来表示不同类别之间的关系。每一行代表真实的类别分布,而列则对应着预测的结果。具体形式如下表所示:
| | Predicted Class A | Predicted Class B | Predicted Class C |
|--|-------------------|-------------------|
| Actual Class A| TPA | FPA | FPA |
| Actual Class B| FPB | TPB | FPB |
| Actual Class C| FCC | FCC | CPC |
其中:
- `TPX` 表示真正例计数(即实际为 X 类并成功预测为 X)
- `FPY` 表示假正例计数(即实际上不是 Y 类却被误认为是 Y)
通过这种方式,能够直观地看到各类别之间相互影响的程度,并据此评估模型性能[^4]。
```python
from sklearn.metrics import confusion_matrix
import numpy as np
# Example true labels and predictions for demonstration purposes only.
y_true = ["cat", "dog", "rabbit", "cat", "dog", "rabbit"]
y_pred = ["cat", "dog", "cat", "rabbit", "dog", "rabbit"]
labels = ['cat', 'dog', 'rabbit']
cm = confusion_matrix(y_true, y_pred, labels=labels)
print(cm)
```
此代码片段展示了如何利用 Python 中的 scikit-learn 库快速生成混淆矩阵。这里使用的数据集仅为演示用途;实践中应替换为自己的测试数据集。
阅读全文