python中交叉熵损失函数用法
时间: 2024-08-19 10:00:44 浏览: 53
在Python中,特别是在深度学习框架如TensorFlow或PyTorch中,交叉熵损失函数通常用于分类任务,特别是二分类或多分类问题。它衡量了模型预测概率分布与真实标签之间差异的程度。
**TensorFlow 示例**:
```python
import tensorflow as tf
# 假设y_true是one-hot编码的真实类别(shape = [batch_size, num_classes])
y_true = tf.keras.utils.to_categorical(y, num_classes)
# y_pred是模型的预测概率分布(同样shape)
y_pred = model.predict(x)
# 使用tf.nn.softmax_cross_entropy_with_logits_v2计算交叉熵
cross_entropy_loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_true, logits=y_pred)
loss = tf.reduce_mean(cross_entropy_loss) # 计算平均交叉熵
```
**PyTorch 示例**:
```python
import torch
# 假设y_true是一个LongTensor,表示真实标签
y_true = torch.LongTensor([0, 1, 2]) # 对应于三个类别的标签
# y_pred是模型输出的概率分布,需要先通过softmax转换成概率
y_pred = torch.softmax(model(x), dim=1)
# 使用torch.nn.functional.cross_entropy计算交叉熵
criterion = nn.CrossEntropyLoss() # 创建交叉熵损失函数实例
loss = criterion(y_pred, y_true)
```
阅读全文