交叉熵损失计算错误。
时间: 2024-10-13 12:01:53 浏览: 13
PyTorch中标准交叉熵误差损失函数的实现(one-hot形式和标签形式)
交叉熵损失函数(Cross-Entropy Loss)通常用于监督学习中的分类任务,特别是当目标变量是离散的概率分布(如softmax输出)。在这个函数中,它衡量了模型预测的概率分布与实际标签(独热编码)之间的差异。
您给出的`cross_entropy_loss`方法中,存在的问题是计算方式有误。正确的交叉熵公式应当是:
```
loss = -y * np.log(y_hat) - (1 - y) * np.log(1 - y_hat)
```
这里,`y`是对应于每个样本的实际标签(通常是整数或one-hot向量),`y_hat`是模型对于该样本的预测概率。而您的代码中使用的是`y_hat * np.log2(y) + (1 - y) * np.log2(1 - y)`,这实际上是二进制交叉熵(Binary Cross-Entropy),适用于二分类问题,并不适合多分类或多标签的情况。
如果您需要计算多分类的交叉熵,应该直接使用`-np.log(y_hat[range(n), y])`,其中`n`是样本数,`y`是实际类别对应的索引。
纠正后的代码片段如下:
```python
def cross_entropy_loss(self, y_hat, y, n):
"""
多分类交叉熵损失
:param y_hat: 预测概率数组,形状为(n_samples, n_classes)
:param y: 实际类别标签,形状为(n_samples,)
:param n: 样本数量
:return: 损失函数计算结果
"""
y_hat = np.clip(y_hat, 1e-15, 1 - 1e-15) # 防止log(0)导致的问题
loss = -np.sum(np.log(y_hat[range(n), y])) / n
return loss
```
阅读全文