计算平均精度
时间: 2023-09-18 18:13:53 浏览: 48
好的,以下是一个简单的计算平均精度的代码示例,使用的是PyTorch框架:
```python
import torch
def compute_ap(predictions, targets):
# 对预测结果按照置信度从大到小排序
sorted_idx = torch.argsort(predictions, descending=True)
# 将对应的标签提取出来
targets = targets[sorted_idx]
# 计算每个位置上的precision和recall
true_positives = targets.float().cumsum(dim=0)
false_positives = (1 - targets).float().cumsum(dim=0)
precision = true_positives / (true_positives + false_positives)
recall = true_positives / targets.sum()
# 将recall从0到1分成11个点
recall_levels = torch.linspace(0, 1, 11)
# 对于每个recall水平,找到最大的precision
# 注意:这里是采用插值的方式计算的,而不是简单的最大值
precisions = torch.zeros_like(recall_levels)
for i, recall_level in enumerate(recall_levels):
recalls_above_level = recall >= recall_level
if recalls_above_level.any():
precisions[i] = precision[recalls_above_level].max()
# 计算平均精度(AP)
ap = precisions.mean()
return ap
```
这个函数接受两个张量作为输入:预测结果(`predictions`)和真实标签(`targets`)。其中,`predictions`是一个大小为`[N]`的一维张量,表示模型对N个样本的预测结果,`targets`是一个大小为`[N]`的一维张量,表示N个样本的真实标签(二进制0/1)。函数返回一个标量,表示平均精度(AP)。