python通过tp,fp,fn计算pr以及AP
时间: 2024-01-10 14:04:34 浏览: 74
setting.xml文件,修改Maven仓库指向至阿里仓
PR指的是精确率(Precision)和召回率(Recall)的曲线,AP指的是平均精度(Average Precision)。
首先,我们需要计算出每个不同的阈值下的TP、FP和FN的数量,可以使用如下代码:
```python
def compute_counts(true_labels, pred_scores, threshold):
tp, fp, fn = 0, 0, 0
for i in range(len(true_labels)):
if pred_scores[i] >= threshold and true_labels[i] == 1:
tp += 1
elif pred_scores[i] >= threshold and true_labels[i] == 0:
fp += 1
elif pred_scores[i] < threshold and true_labels[i] == 1:
fn += 1
return tp, fp, fn
```
其中,`true_labels`是一个包含0和1的列表,表示每个样本的真实标签;`pred_scores`是一个列表,表示每个样本的预测分数;`threshold`是阈值,表示分数大于等于该阈值的样本被视为正例。
接下来,我们可以计算出每个不同的阈值下的精确率和召回率:
```python
def compute_pr(true_labels, pred_scores):
thresholds = sorted(set(pred_scores), reverse=True)
pr = []
for threshold in thresholds:
tp, fp, fn = compute_counts(true_labels, pred_scores, threshold)
precision = tp / (tp + fp) if tp + fp > 0 else 0
recall = tp / (tp + fn) if tp + fn > 0 else 0
pr.append((precision, recall))
return pr
```
其中,`thresholds`是一个列表,包含了所有不同的阈值,按照从大到小的顺序排列;`pr`是一个列表,包含了每个不同的阈值下的精确率和召回率。我们可以使用这个函数来计算PR曲线。
最后,我们可以使用下面的代码来计算AP:
```python
def compute_ap(pr):
ap = 0
for i in range(len(pr) - 1):
ap += (pr[i+1][1] - pr[i][1]) * (pr[i][0] + pr[i+1][0]) / 2
return ap
```
其中,`pr`是一个包含了每个不同的阈值下的精确率和召回率的列表。我们可以使用这个函数来计算AP。
阅读全文