写出一段不平衡样本的损失函数代码
时间: 2023-06-11 16:09:10 浏览: 43
假设我们要解决一个二分类问题,其中正样本有较少的比例,可以使用 Focal Loss,它会对正样本的损失进行放大,使得模型更加关注于正样本的分类准确性。
以下是使用 PyTorch 实现的 Focal Loss 的代码:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
class FocalLoss(nn.Module):
def __init__(self, gamma=2, alpha=None):
super().__init__()
self.gamma = gamma
self.alpha = alpha
def forward(self, input, target):
ce_loss = F.cross_entropy(input, target, reduction='none')
pt = torch.exp(-ce_loss)
if self.alpha is not None:
alpha_t = self.alpha[target]
pt = alpha_t * pt
loss = (1 - pt) ** self.gamma * ce_loss
return loss.mean()
```
其中,`gamma` 是一个超参数,控制着正样本损失的放大程度,`alpha` 是一个可选参数,可以用于进一步调整正样本的权重。在使用时,可以将正样本设置为较低的权重,如 `alpha = [0.1, 0.9]`,使得模型更加关注于正样本的分类准确性。