pytorch AUC/ROC
时间: 2024-08-12 17:06:06 浏览: 60
PyTorch中的AUC(Area Under the Curve)和ROC(Receiver Operating Characteristic)曲线是用来评估二分类模型性能的重要指标。AUC衡量的是真正例率(TPR,True Positive Rate)随着假正例率(FPR,False Positive Rate)变化的曲线下面积,值越大表示模型区分正负样本的能力越强。
ROC曲线是一条以FPR为横坐标、TPR为纵坐标的曲线,它显示了模型在所有可能的阈值下,各种FPR和TPR对的组合。45度线代表随机猜测的结果,而完全分开的曲线(TPR=1-FPR)意味着完美预测。
在PyTorch中,计算AUC通常通过`sklearn.metrics.roc_auc_score`函数,首先你需要对模型的预测概率或输出进行排序,并计算对应的真阳性率和假阳性率,然后将它们传入该函数计算AUC值。你可以使用`torch.sigmoid`函数获取softmax预测的概率,对于`logits`则需要先转换。
示例代码:
```python
import torch
from sklearn import metrics
# 假设predictions是模型的sigmoid输出或logits
predictions = model(inputs).sigmoid() # 对于logits需要用sigmoid转换
true_labels = torch.tensor([0, 1, 0, 1]) # 真实标签
# 计算AUC和ROC曲线数据
fpr, tpr, _ = metrics.roc_curve(true_labels, predictions)
auc = metrics.roc_auc_score(true_labels, predictions)
# 可视化ROC曲线
plt.plot(fpr, tpr, label=f"AUC = {auc:.2f}")
```
阅读全文