focal loss pytorch代码
时间: 2023-08-08 19:10:26 浏览: 178
以下是使用PyTorch实现Focal Loss的代码示例:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
class FocalLoss(nn.Module):
def __init__(self, alpha=1, gamma=2, reduction='mean'):
super(FocalLoss, self).__init__()
self.alpha = alpha
self.gamma = gamma
self.reduction = reduction
def forward(self, inputs, targets):
inputs = inputs.view(-1)
targets = targets.view(-1)
# 计算交叉熵损失
ce_loss = F.cross_entropy(inputs, targets, reduction='none')
# 计算权重
pt = torch.exp(-ce_loss)
weight = (self.alpha * (1 - pt)**self.gamma)
# 根据权重调整损失
loss = weight * ce_loss
if self.reduction == 'mean':
return loss.mean()
elif self.reduction == 'sum':
return loss.sum()
else:
return loss
# 使用示例
criterion = FocalLoss(alpha=1, gamma=2, reduction='mean')
inputs = torch.randn(10, 3)
targets = torch.randint(0, 3, (10,))
loss = criterion(inputs, targets)
print(loss)
```
在上面的代码中,我们定义了一个继承自`nn.Module`的`FocalLoss`类,它接受`alpha`和`gamma`作为超参数,并且可以选择损失的缩减方式(`reduction='mean'`表示计算平均损失,`reduction='sum'`表示计算总损失)。
在前向传播方法中,我们首先将输入和目标展平,然后计算交叉熵损失。接下来,我们计算每个样本的权重`weight`,并使用该权重调整交叉熵损失。最后,根据缩减方式返回相应的损失值。
在使用时,我们创建一个`FocalLoss`类的实例,并将其作为损失函数使用。输入和目标张量的形状可以根据实际情况进行调整。
请注意,这只是Focal Loss的基本实现示例,你可以根据自己的需求进行修改和扩展。
阅读全文