我想利用pytorch计算准确率、精度、召回率和 F1 值
时间: 2023-09-11 09:08:35 浏览: 186
好的,下面是如何使用 PyTorch 计算准确率、精度、召回率和 F1 值的示例代码:
```python
import torch
def accuracy(y_pred, y_true):
"""计算准确率"""
preds = torch.argmax(y_pred, dim=1)
correct = torch.eq(preds, y_true).sum().item()
acc = correct / len(y_true)
return acc
def precision(y_pred, y_true):
"""计算精度"""
preds = torch.argmax(y_pred, dim=1)
true_positives = torch.eq(preds, y_true) & torch.eq(preds, 1)
false_positives = torch.ne(preds, y_true) & torch.eq(preds, 1)
p = true_positives.sum().item() / (true_positives.sum().item() + false_positives.sum().item())
return p
def recall(y_pred, y_true):
"""计算召回率"""
preds = torch.argmax(y_pred, dim=1)
true_positives = torch.eq(preds, y_true) & torch.eq(preds, 1)
false_negatives = torch.ne(preds, y_true) & torch.eq(y_true, 1)
r = true_positives.sum().item() / (true_positives.sum().item() + false_negatives.sum().item())
return r
def f1_score(y_pred, y_true):
"""计算 F1 值"""
p = precision(y_pred, y_true)
r = recall(y_pred, y_true)
f1 = 2 * (p * r) / (p + r)
return f1
```
这里假设 `y_pred` 是模型的预测输出,`y_true` 是真实的标签。您可以将这些函数添加到您的代码中,并使用它们来计算指标。
阅读全文