写一个不会出现索引错误的损失函数class FocalLoss(nn.Module)
时间: 2024-04-29 18:21:14 浏览: 189
下面是一个不会出现索引错误的Focal Loss类的实现,其中使用了torch.gather和torch.clamp操作来避免出现索引错误:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
class FocalLoss(nn.Module):
def __init__(self, alpha=1, gamma=2, reduction='mean'):
super(FocalLoss, self).__init__()
self.alpha = alpha
self.gamma = gamma
self.reduction = reduction
def forward(self, inputs, targets):
pt = F.softmax(inputs, dim=1)
pt = torch.gather(pt, 1, targets.unsqueeze(1)).squeeze(1)
alpha = torch.ones_like(pt) * self.alpha
alpha = torch.where(torch.eq(targets, 1), alpha, 1 - alpha)
loss = -alpha * torch.pow(1 - pt, self.gamma) * torch.log(pt)
loss = torch.clamp(loss, min=-1e4, max=1e4) # 避免出现nan
if self.reduction == 'mean':
loss = torch.mean(loss)
elif self.reduction == 'sum':
loss = torch.sum(loss)
return loss
```
使用时,可以按照下面的方式调用:
```python
criterion = FocalLoss(alpha=0.25, gamma=2, reduction='mean')
loss = criterion(inputs, targets)
```
阅读全文