已有所有logits和预测标签,给我一个代码画ROC和PR曲线
时间: 2024-02-24 10:56:46 浏览: 130
好的,下面是一个示例代码,可以根据给定的预测标签和真实标签画出 ROC 曲线和 PR 曲线:
```
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc, precision_recall_curve
# 假设已有所有 logits 和预测标签 pred_labels,以及真实标签 true_labels
logits = np.random.randn(100)
pred_labels = (logits > 0).astype(int)
true_labels = np.random.randint(0, 2, size=100)
# 计算 ROC 曲线和 AUC
fpr, tpr, thresholds = roc_curve(true_labels, logits)
roc_auc = auc(fpr, tpr)
# 绘制 ROC 曲线
plt.plot(fpr, tpr, color='darkorange', lw=2, label='ROC curve (area = %0.2f)' % roc_auc)
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver operating characteristic')
plt.legend(loc="lower right")
plt.show()
# 计算 PR 曲线和 AUC
precision, recall, thresholds = precision_recall_curve(true_labels, logits)
pr_auc = auc(recall, precision)
# 绘制 PR 曲线
plt.plot(recall, precision, color='darkorange', lw=2, label='PR curve (area = %0.2f)' % pr_auc)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.title('Precision-Recall curve')
plt.legend(loc="lower left")
plt.show()
```
在这个代码中,我们首先利用 `sklearn.metrics` 中的 `roc_curve` 和 `precision_recall_curve` 函数计算 ROC 曲线和 PR 曲线上的点,然后分别利用 `auc` 函数计算曲线下的面积,最后利用 `matplotlib` 绘制出曲线。在绘制 ROC 曲线时,我们还绘制了一个对角线来表示随机猜测的结果。
阅读全文