python绘制多分类的AUC曲线
时间: 2023-10-04 21:04:41 浏览: 100
首先,要绘制多分类的AUC曲线,需要将问题转换成多个二分类问题。一种常见的方法是将每个类别与其他所有类别分别作为正例和负例,计算每个类别的AUC值。然后,可以将所有类别的AUC值综合起来画出多分类的AUC曲线。
以下是一个示例代码,使用了scikit-learn库中的multiclass_roc_auc_score()函数来计算多分类AUC值,并使用matplotlib库绘制曲线:
```python
import numpy as np
import matplotlib.pyplot as plt
from itertools import cycle
from sklearn.metrics import roc_curve, auc
from sklearn.preprocessing import label_binarize
from sklearn.multiclass import OneVsRestClassifier
from sklearn.datasets import make_classification
# 生成样本数据
X, y = make_classification(n_samples=1000, n_features=10, n_classes=5, n_informative=5,
n_redundant=0, n_clusters_per_class=1, random_state=42)
# 将类别转化为二值标签
y_bin = label_binarize(y, classes=list(range(5)))
# 训练多个二分类模型,并计算每个类别的AUC值
fpr = dict()
tpr = dict()
roc_auc = dict()
n_classes = y_bin.shape[1]
for i in range(n_classes):
clf = OneVsRestClassifier(svm.SVC(kernel='linear', probability=True, random_state=42))
clf.fit(X, y_bin[:, i])
y_score = clf.decision_function(X)
fpr[i], tpr[i], _ = roc_curve(y_bin[:, i], y_score)
roc_auc[i] = auc(fpr[i], tpr[i])
# 计算平均AUC值
roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])
roc_auc["macro"] = np.mean(list(roc_auc.values()))
# 绘制多分类AUC曲线
plt.figure()
lw = 2
colors = cycle(['aqua', 'darkorange', 'cornflowerblue', 'green', 'red'])
for i, color in zip(range(n_classes), colors):
plt.plot(fpr[i], tpr[i], color=color, lw=lw,
label='ROC curve of class {0} (AUC = {1:0.2f})'
''.format(i, roc_auc[i]))
plt.plot(fpr["micro"], tpr["micro"], color='deeppink',
label='micro-average ROC curve (AUC = {0:0.2f})'
''.format(roc_auc["micro"]),
linestyle=':', linewidth=4)
plt.plot(fpr["macro"], tpr["macro"], color='navy',
label='macro-average ROC curve (AUC = {0:0.2f})'
''.format(roc_auc["macro"]),
linestyle=':', linewidth=4)
plt.plot([0, 1], [0, 1], 'k--', lw=lw)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Multi-class ROC Curve')
plt.legend(loc="lower right")
plt.show()
```
注意,这里使用了OneVsRestClassifier()函数将多分类问题转化为多个二分类问题,SVM作为分类器。你可以根据实际情况选择不同的分类器。
阅读全文