Write a Focal loss function in loss.py using Python
Time: 2024-05-14 22:15:17
Below is a simple Python implementation of the Focal loss function using TensorFlow:
```python
import tensorflow as tf


def focal_loss(y_true, y_pred, gamma=2.0, alpha=0.25):
    """Focal loss for multi-class classification.

    Args:
        y_true: A tensor of shape [batch_size, num_classes] containing one-hot ground-truth labels
        y_pred: A tensor of shape [batch_size, num_classes] containing predicted logits
        gamma: The focusing parameter. Default is 2.0
        alpha: The weighting parameter. Default is 0.25 for imbalanced datasets
    Returns:
        A tensor of shape [batch_size] containing the focal loss for each sample in the batch
    """
    # Make sure labels and logits share a dtype (Keras may pass integer labels)
    y_true = tf.cast(y_true, y_pred.dtype)
    # Per-class log-probabilities, computed stably from the logits
    log_prob = tf.nn.log_softmax(y_pred, axis=-1)
    # Softmax probabilities p
    prob = tf.exp(log_prob)
    # Modulating factor (1 - p)^gamma, which down-weights well-classified examples
    weight = tf.pow(1.0 - prob, gamma)
    # Per-class focal term: -alpha * y * (1 - p)^gamma * log(p), shape [batch_size, num_classes]
    fl_loss = -alpha * y_true * weight * log_prob
    # Sum over classes; with one-hot labels only the true class contributes
    return tf.reduce_sum(fl_loss, axis=-1)
```
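For reference, the function above computes the multi-class form of the focal loss from Lin et al. (2017, "Focal Loss for Dense Object Detection"), with $z$ the logits, $p = \mathrm{softmax}(z)$, $y$ the one-hot label and $t$ the index of the true class. Using one scalar $\alpha$ for every class, as this code does, is a simplification of the per-class weight $\alpha_t$ described in that paper:

$$
\mathrm{FL}(y, z) = -\sum_{c=1}^{C} \alpha \, y_c \, (1 - p_c)^{\gamma} \log p_c = -\alpha \, (1 - p_t)^{\gamma} \log p_t
$$

With $\gamma = 0$ and $\alpha = 1$ this reduces to ordinary softmax cross-entropy $-\log p_t$; a larger $\gamma$ shrinks the loss of examples whose $p_t$ is already close to 1, so training focuses on the hard ones.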
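To use it, pass the function as the `loss` argument of the model's `compile` method. Because `focal_loss` expects logits, the model's last layer should not apply a softmax activation, and the labels should be one-hot encoded:
```python
model.compile(optimizer='adam', loss=focal_loss, metrics=['accuracy'])
```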
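To sanity-check the numbers the TensorFlow function produces, a framework-free reference of the same per-sample formula can help. The sketch below uses only the standard library; the helper names `softmax`, `focal_loss_reference` and `cross_entropy_reference` are hypothetical and not part of TensorFlow or Keras. It assumes a single sample given as a list of logits plus the integer index of the true class.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats (hypothetical helper)."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def focal_loss_reference(true_class, logits, gamma=2.0, alpha=0.25):
    """Focal loss for one sample: -alpha * (1 - p_t)**gamma * log(p_t)."""
    p_t = softmax(logits)[true_class]
    return -alpha * (1.0 - p_t) ** gamma * math.log(p_t)


def cross_entropy_reference(true_class, logits):
    """Plain softmax cross-entropy for one sample: -log(p_t)."""
    return -math.log(softmax(logits)[true_class])


if __name__ == "__main__":
    logits = [2.0, 0.0, -1.0]
    ce = cross_entropy_reference(0, logits)
    # gamma=0 and alpha=1 recover ordinary cross-entropy
    print(focal_loss_reference(0, logits, gamma=0.0, alpha=1.0), ce)
    # the defaults gamma=2.0, alpha=0.25 give a smaller loss for this fairly confident prediction
    print(focal_loss_reference(0, logits), ce)
```

Feeding the same logits and one-hot label to `focal_loss` should give a matching per-sample value, up to floating-point error.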