Circle Loss 代码实现与 Focal Loss 的简单实现(二分类 + 多分类)
时间: 2023-09-27 14:06:35 浏览: 274
下面是 circle loss 和 focal loss 的简单实现(包括二分类和多分类):
Circle Loss:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
class CircleLoss(nn.Module):
    """Pair-wise Circle Loss (Sun et al., CVPR 2020).

    For every unordered pair (i, j), i < j, of samples in the batch, the
    similarity ``feats[i] @ feats[j]`` is a within-class score ``s_p`` when
    the labels match and a between-class score ``s_n`` otherwise. The loss is

        log(1 + sum_n exp(logit_n) * sum_p exp(logit_p))
            = softplus(logsumexp(logit_n) + logsumexp(logit_p))

    Args:
        m: relaxation margin (the paper's ``m``).
        s: scale factor (the paper's ``gamma``).

    NOTE: the margins assume cosine similarities in [-1, 1], so ``feats``
    should be L2-normalised by the caller (e.g. ``F.normalize(feats, dim=1)``).
    This module does not normalise them itself.
    """

    def __init__(self, m=0.25, s=30):
        super().__init__()
        self.m = m
        self.s = s

    def forward(self, feats, labels):
        """Return the scalar Circle Loss for one batch.

        Args:
            feats: 2-D embedding matrix, one row per sample.
            labels: one class label per row of ``feats``; shape ``(N,)`` or
                ``(N, 1)``.

        Returns:
            A 0-dim tensor. If the batch contains no positive pair (or no
            negative pair), the corresponding logsumexp is ``-inf`` and the
            loss evaluates to 0.
        """
        sim_mat = feats @ feats.t()

        labels = labels.reshape(-1)
        same = labels.unsqueeze(1) == labels.unsqueeze(0)
        # Strict upper triangle: each unordered pair counted once, and the
        # trivial self-similarity on the diagonal excluded.
        pos_sim = sim_mat[same.triu(diagonal=1)]
        neg_sim = sim_mat[(~same).triu(diagonal=1)]

        # Adaptive weights are treated as constants (no gradient), as in the
        # paper; the optimal scores are O_p = 1 + m and O_n = -m.
        alpha_p = F.relu(1 + self.m - pos_sim.detach())
        alpha_n = F.relu(neg_sim.detach() + self.m)
        delta_p = 1 - self.m
        delta_n = self.m

        logit_p = -self.s * alpha_p * (pos_sim - delta_p)
        logit_n = self.s * alpha_n * (neg_sim - delta_n)

        # The paper MULTIPLIES the two sums inside the log, which in log space
        # is the sum of two separate logsumexps. A single logsumexp over the
        # concatenated logits would add the sums instead.
        return F.softplus(
            torch.logsumexp(logit_n, dim=0) + torch.logsumexp(logit_p, dim=0)
        )
```
Focal Loss:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
class FocalLoss(nn.Module):
    """Focal loss (Lin et al., 2017) computed from raw logits.

    The mode is chosen from the shapes/dtypes of the arguments:

    * Binary / multi-label: ``targets`` has the same shape as ``inputs`` and
      holds 0/1 labels (any dtype; cast to ``inputs.dtype``). A sigmoid is
      applied element-wise. ``alpha`` weights positives and ``1 - alpha``
      weights negatives.
    * Multi-class: ``inputs`` is ``(N, C, ...)`` and ``targets`` is an
      integer tensor ``(N, ...)`` of class indices. A softmax over dim 1 is
      used. ``alpha`` is either a scalar applied to every sample or a
      length-``C`` sequence/tensor of per-class weights.

    Args:
        alpha: balancing weight (see above); ``None`` disables weighting.
        gamma: focusing parameter; ``gamma=0`` reduces to (weighted) cross
            entropy.
        reduction: ``'mean'`` or ``'sum'``; any other value returns the
            unreduced per-element loss.
    """

    def __init__(self, alpha=0.25, gamma=2, reduction='mean'):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs, targets):
        """Return the focal loss of logits ``inputs`` against ``targets``."""
        # Integer class indices with one fewer dim than the logits -> softmax
        # mode. Every other call keeps the original sigmoid/BCE behaviour.
        is_multiclass = (
            targets.dim() == inputs.dim() - 1 and not targets.is_floating_point()
        )
        if is_multiclass:
            loss = self._multiclass_loss(inputs, targets)
        else:
            loss = self._binary_loss(inputs, targets.to(inputs.dtype))
        return self._reduce(loss)

    def _binary_loss(self, inputs, targets):
        """Element-wise sigmoid focal loss for 0/1 targets."""
        bce = F.binary_cross_entropy_with_logits(inputs, targets, reduction='none')
        # For 0/1 targets bce = -log(p_t), hence p_t = exp(-bce).
        p_t = torch.exp(-bce)
        loss = (1 - p_t) ** self.gamma * bce
        if self.alpha is not None:
            # alpha_t = alpha for positives, 1 - alpha for negatives. A single
            # constant for both classes would only rescale the loss.
            alpha_t = self.alpha * targets + (1 - self.alpha) * (1 - targets)
            loss = alpha_t * loss
        return loss

    def _multiclass_loss(self, inputs, targets):
        """Per-sample softmax focal loss for integer class-index targets."""
        targets = targets.long()  # gather / indexing require int64 indices
        log_p_t = F.log_softmax(inputs, dim=1).gather(1, targets.unsqueeze(1)).squeeze(1)
        p_t = log_p_t.exp()
        loss = -((1 - p_t) ** self.gamma) * log_p_t
        if self.alpha is not None:
            weights = torch.as_tensor(self.alpha, dtype=inputs.dtype, device=inputs.device)
            # A scalar alpha scales every sample; otherwise index it by class.
            loss = (weights if weights.dim() == 0 else weights[targets]) * loss
        return loss

    def _reduce(self, loss):
        """Apply ``self.reduction`` to the per-element loss."""
        if self.reduction == 'mean':
            return torch.mean(loss)
        if self.reduction == 'sum':
            return torch.sum(loss)
        return loss
```
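以上代码适用于 PyTorch 深度学习框架。注意:Circle Loss 中的间隔参数假定相似度为余弦相似度(取值范围 [-1, 1]),因此传入的特征应先做 L2 归一化,例如 `feats = F.normalize(feats, dim=1)`。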
阅读全文