PyTorch Focal Loss 代码
时间: 2023-07-02 16:04:40 浏览: 278
Pytorch 实现focal_loss 多类别和二分类示例
4星 · 用户满意度95%
以下是 PyTorch 中实现 Focal Loss 的代码:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
class FocalLoss(nn.Module):
def __init__(self, gamma=2.0, alpha=None, reduction='mean'):
super(FocalLoss, self).__init__()
self.gamma = gamma
self.alpha = alpha
if isinstance(alpha, (float, int)): self.alpha = torch.Tensor([alpha, 1 - alpha])
if isinstance(alpha, list): self.alpha = torch.Tensor(alpha)
self.reduction = reduction
def forward(self, inputs, targets):
BCE_loss = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
pt = torch.exp(-BCE_loss)
focal_loss = (1 - pt) ** self.gamma * BCE_loss
if self.alpha is not None:
focal_loss = self.alpha[0] * focal_loss * targets + self.alpha[1] * focal_loss * (1 - targets)
if self.reduction == 'mean':
return torch.mean(focal_loss)
elif self.reduction == 'sum':
return torch.sum(focal_loss)
else:
return focal_loss
```
在初始化函数中,我们定义了 Focal Loss 的超参数 gamma 和 alpha,以及损失函数计算的方式 reduction。在 forward 函数中,我们先计算二元交叉熵损失 BCE_loss,然后计算每个样本的 focal loss。最后,我们可以选择性地乘以 alpha 来加权不同类别的损失,并根据 reduction 的方式返回平均损失或总损失。
阅读全文