帮我写一个Focal loss代码
时间: 2023-05-29 14:07:40 浏览: 110
focal loss.py
以下是一个简单的Focal loss代码实现:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
class FocalLoss(nn.Module):
def __init__(self, alpha=0.25, gamma=2.0):
super(FocalLoss, self).__init__()
self.alpha = alpha
self.gamma = gamma
def forward(self, inputs, targets):
targets = targets.view(-1, 1)
logpt = F.log_softmax(inputs, dim=1)
logpt = logpt.gather(1, targets)
logpt = logpt.view(-1)
pt = logpt.exp()
alpha = torch.ones(inputs.size(1)) * self.alpha
alpha[targets.data] = 1 - self.alpha
weights = alpha.gather(0, targets.view(-1))
weights = weights.view(-1)
loss = -weights * (1 - pt) ** self.gamma * logpt
return loss.mean()
```
这里的alpha是类别权重,gamma是调节难易样本权重的超参数。在前向传播中,首先将targets转换为列向量,然后计算log_softmax,接着用gather函数根据targets选择对应的logpt值。然后,根据alpha和targets计算权重weights,最后计算Focal loss并返回平均值。
阅读全文