huber损失函数代码
时间: 2023-12-25 09:58:03 浏览: 157
损失函数为Cross entropy的手写数字识别神经网络代码与实现
以下是 Huber 损失函数的 Python 代码实现:
```python
import numpy as np
def huber_loss(y_true, y_pred, delta=1):
"""
计算 Huber 损失函数
:param y_true: 实际值
:param y_pred: 预测值
:param delta: 阈值,用于判断是 L1 损失还是 L2 损失
:return: Huber 损失
"""
error = y_true - y_pred
if np.abs(error) <= delta:
# L2 损失
loss = 0.5 * error ** 2
else:
# L1 损失
loss = delta * (np.abs(error) - 0.5 * delta)
return loss
```
其中,`y_true` 为实际值,`y_pred` 为预测值,`delta` 为阈值,用于判断是 L1 损失还是 L2 损失。当误差小于等于阈值时,使用 L2 损失;当误差大于阈值时,使用 L1 损失。
阅读全文