pytorch序列推荐计算hit ratio@5的代码
时间: 2023-10-25 12:02:53 浏览: 71
下面是使用PyTorch计算hit ratio@5的代码示例:
```python
import torch
def hit_ratio_at_k(predictions, targets, k):
_, indices = torch.topk(predictions, k, dim=1)
targets = targets.unsqueeze(1)
hits = torch.sum(torch.eq(indices, targets), dim=1)
hr = torch.mean((hits > 0).float())
return hr
# 示例数据
predictions = torch.tensor([[0.2, 0.3, 0.4, 0.6, 0.1],
[0.3, 0.1, 0.5, 0.4, 0.2],
[0.1, 0.5, 0.3, 0.2, 0.4]])
targets = torch.tensor([2, 4, 1])
k = 5
hr = hit_ratio_at_k(predictions, targets, k)
print(f"HIT ratio@{k}: {hr:.2f}")
```
这段代码计算了一个包含3个样本的预测结果和对应的目标值的示例数据。其中`predictions`是一个包含每个样本推荐得分的矩阵,`targets`是对应每个样本的目标值。`hit_ratio_at_k`函数接收预测结果和目标值作为输入,计算命中率。首先,使用`torch.topk`函数获取每个样本分数最高的k个索引。然后,对目标值进行unsqueeze操作,以匹配索引的维度。接着,通过比较indices和目标值的相等性得到命中数,再计算平均命中率。最后,输出计算结果。
以上示例代码展示了使用PyTorch计算hit ratio@5的方法,你可以根据不同的需求和数据类型进行相应的调整和优化。