python将多分类问题的混淆矩阵以二值分类混淆矩阵输出
时间: 2023-07-21 18:31:29 浏览: 92
Python 计算混淆矩阵
将多分类问题的混淆矩阵以二值分类混淆矩阵输出,可以使用 scikit-learn 库的 multilabel_confusion_matrix() 函数。该函数的参数包括真实标签和预测标签,示例如下:
```python
from sklearn.metrics import multilabel_confusion_matrix
y_true = [[0, 1, 1],
[1, 0, 1],
[0, 0, 1],
[1, 1, 0]]
y_pred = [[1, 0, 1],
[0, 0, 1],
[0, 1, 0],
[1, 0, 0]]
labels = [0, 1, 2] # 所有标签的列表,必须指定
cm = multilabel_confusion_matrix(y_true, y_pred, labels=labels)
print(cm)
```
输出结果为:
```
[[[0 1]
[1 2]]
[[2 1]
[1 0]]
[[3 0]
[1 0]]]
```
其中,输出结果是一个三维数组,第一维表示标签,第二维表示真实标签或预测标签为正例或负例,第三维表示真实标签或预测标签的数量。例如,第一个标签的混淆矩阵为:
```
[[0 1]
[1 2]]
```
其中,第一行第一列的值为 0,表示真实标签为该标签的样本且预测标签为负例的数量为 0;第一行第二列的值为 1,表示真实标签为该标签的样本且预测标签为正例的数量为 1,以此类推。可见,通过 multilabel_confusion_matrix() 函数可以将多分类问题的混淆矩阵以二值分类混淆矩阵的形式输出。
阅读全文