focal loss代码实现
时间: 2023-09-12 22:03:50 浏览: 110
以下是 focal loss 的代码实现(Python):
```python
import tensorflow as tf
def focal_loss(y_true, y_pred, gamma=2.0, alpha=0.25):
"""
Focal loss for multi-class classification.
Args:
y_true (tensor): True labels of shape (batch_size, num_classes).
y_pred (tensor): Predicted labels of shape (batch_size, num_classes).
gamma (float): Focusing parameter. Default is 2.0.
alpha (float): Weighting factor for positive class. Default is 0.25.
Returns:
Focal loss for the batch.
"""
# Convert true labels to one-hot encoding
y_true = tf.one_hot(tf.cast(y_true, tf.int32), depth=y_pred.shape[-1])
# Calculate cross-entropy loss
pt = tf.where(tf.equal(y_true, 1), y_pred, 1 - y_pred)
ce_loss = -tf.reduce_sum(alpha * tf.pow(1 - pt, gamma) * tf.log(pt), axis=1)
# Calculate focal loss
fl_loss = tf.reduce_mean(ce_loss)
return fl_loss
```
这个实现使用了 TensorFlow 库实现,支持多类别分类问题。gamma 参数控制背景类别相对于前景类别的权重,alpha 参数调节正类别权重。在实现中,先将真实标签转换为 one-hot 编码,然后计算交叉熵损失。最后使用 focal loss 更新损失值。
阅读全文