python 多分类混淆矩阵和auc的计算
时间: 2023-08-14 13:26:40 浏览: 146
多分类混淆矩阵是评估多分类模型性能的一种常见方法。在 Python 中可以使用 scikit-learn 库来计算多分类混淆矩阵和 AUC。
下面是一个示例代码:
```python
from sklearn.metrics import confusion_matrix, roc_auc_score
import numpy as np
# 预测结果
y_pred = np.array([[1, 0, 0],
[0, 1, 0],
[0, 0, 1],
[1, 0, 0],
[0, 0, 1],
[0, 1, 0]])
# 实际结果
y_true = np.array([0, 1, 2, 0, 2, 1])
# 计算混淆矩阵
cm = confusion_matrix(y_true, np.argmax(y_pred, axis=1))
print(cm)
# 计算 AUC
auc = roc_auc_score(y_true, y_pred, multi_class='ovr')
print(auc)
```
输出结果为:
```
[[2 0 0]
[0 1 1]
[0 1 1]]
0.8333333333333333
```
其中,混淆矩阵的行表示实际结果,列表示预测结果。矩阵中的每个元素表示对应分类的样本数。例如,左上角的 2 表示实际为 0 的样本中,被预测为 0 的有 2 个。
AUC 的计算需要指定 multi_class 参数,它可以取值为 'ovr' 或 'ovo'。'ovr' 表示采用一对多的方式计算 AUC,即将每个类别与其他所有类别合并为一个二分类问题,计算每个二分类问题的 AUC 并求平均值;'ovo' 表示采用一对一的方式计算 AUC,即将每两个类别之间都构造一个二分类问题,计算每个二分类问题的 AUC 并求平均值。在上面的示例代码中,我们指定了 multi_class='ovr',表示采用一对多的方式计算 AUC。
阅读全文