### Equalized Focal Loss Code Implementation
Equalized Focal Loss (EFL) addresses class imbalance and hard-example mining by adjusting the standard Focal Loss formulation[^1].
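For reference, the standard Focal Loss on the predicted probability of the true class, $p_t$, is

$$\mathrm{FL}(p_t) = -\alpha_t\,(1 - p_t)^{\gamma}\,\log(p_t),$$

where $\gamma$ is the focusing parameter and $\alpha_t$ is the class-balancing weight. The sketch below additionally rescales each sample by an equalization factor derived from the inverse square-root frequency of its class within the current batch; note this is a simplified, batch-level take on the equalization idea rather than the exact gradient-guided formulation of the cited work: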
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class EqualizedFocalLoss(nn.Module):
    def __init__(self, alpha=0.25, gamma=2.0, num_classes=80, reduction='mean'):
        super(EqualizedFocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.num_classes = num_classes
        self.reduction = reduction

    def forward(self, logits, targets):
        # Calculate class probabilities from logits
        probas = F.softmax(logits, dim=-1)

        # Create one-hot encoding for target classes
        y_onehot = F.one_hot(targets, num_classes=self.num_classes).float()

        # Per-class weights inversely proportional to the square root of
        # each class's frequency in the current batch
        freq_weights = 1 / ((y_onehot.sum(dim=0) + 1e-6) ** 0.5)

        # Normalize the weights so they sum to the number of samples
        norm_freqs = freq_weights * (targets.shape[0] / freq_weights.sum())

        # Equalization factor per sample: norm_freqs is 1-D over classes,
        # so plain indexing by the target labels suffices
        eq_factor = norm_freqs[targets]

        # Standard focal-loss components
        ce_loss = F.cross_entropy(logits, targets, reduction="none")
        p_t = probas.gather(1, targets.unsqueeze(-1)).squeeze(-1) + 1e-9
        fl_modulating_factor = (1 - p_t) ** self.gamma

        # Alpha weighting; gathered at the target index this reduces to alpha
        balanced_fl_weight = self.alpha * y_onehot + (1 - self.alpha) * (1 - y_onehot)
        alpha_t = balanced_fl_weight.gather(1, targets.unsqueeze(-1)).squeeze(-1)

        # Combine all components into the final EFL value
        ef_loss = eq_factor * alpha_t * fl_modulating_factor * ce_loss

        if self.reduction == 'mean':
            return ef_loss.mean()
        elif self.reduction == 'sum':
            return ef_loss.sum()
        else:
            return ef_loss
```
This implementation introduces an `eq_factor` term that scales each training instance's contribution according to the rarity of its class within the current batch: samples from infrequent classes receive proportionally larger weights.
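A quick smoke test with a hypothetical random batch (the 16-sample, 80-class shapes here are illustrative, not prescribed by the module):

```python
import torch

# Toy batch: 16 samples over 80 classes
criterion = EqualizedFocalLoss(alpha=0.25, gamma=2.0, num_classes=80)
logits = torch.randn(16, 80, requires_grad=True)
targets = torch.randint(0, 80, (16,))

loss = criterion(logits, targets)
loss.backward()  # gradients flow through the logits as usual
print(loss.item())
```

Because `eq_factor` depends only on the integer targets, not on the logits, it acts as a constant per-sample weight and does not interfere with backpropagation.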