帮我写一个focal loss
时间: 2023-11-27 11:56:24 浏览: 32
好的,下面是一个简单的 PyTorch 实现:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
class FocalLoss(nn.Module):
def __init__(self, gamma=2, alpha=None, reduction='mean'):
super(FocalLoss, self).__init__()
self.gamma = gamma
self.alpha = alpha
self.reduction = reduction
def forward(self, inputs, targets):
ce_loss = F.cross_entropy(inputs, targets, reduction='none')
pt = torch.exp(-ce_loss)
focal_loss = (1 - pt) ** self.gamma * ce_loss
if self.alpha is not None:
alpha = self.alpha[targets]
focal_loss = alpha * focal_loss
if self.reduction == 'mean':
return focal_loss.mean()
elif self.reduction == 'sum':
return focal_loss.sum()
else:
return focal_loss
```
其中,gamma 表示 Focal Loss 中的一个超参数,alpha 可以用来调整不同类别的权重,reduction 表示损失函数的计算方式(mean 或 sum)。在 forward 函数中,首先计算 cross-entropy loss,然后计算 focal loss,最后根据 alpha 和 reduction 进行处理并返回。