交叉熵损失代码pytorch
时间: 2023-07-23 09:03:03 浏览: 52
在 PyTorch 中,可以使用 `torch.nn.CrossEntropyLoss()` 函数来计算交叉熵损失。以下是一个简单的示例代码:
```python
import torch
import torch.nn as nn
# 假设有一个大小为 num_classes 的分类任务
num_classes = 10
# 创建一个随机的模型输出和目标值
batch_size = 32
output_size = 10
target = torch.randint(0, num_classes, (batch_size,))
output = torch.randn(batch_size, output_size)
# 定义交叉熵损失函数
criterion = nn.CrossEntropyLoss()
# 计算损失
loss = criterion(output, target)
print(loss)
```
在这个示例中,我们首先定义了分类任务的类别数量 `num_classes`,然后创建了一个大小为 `batch_size` 的随机目标值 `target` 和一个大小为 `(batch_size, output_size)` 的随机模型输出 `output`。
接下来,我们使用 `nn.CrossEntropyLoss()` 来创建一个交叉熵损失函数的实例 `criterion`。
最后,我们将模型输出 `output` 和目标值 `target` 作为参数传递给 `criterion` 函数,计算出交叉熵损失 `loss`。
请注意,`nn.CrossEntropyLoss()` 在内部会对模型输出进行 softmax 操作,因此在计算损失时不需要手动进行 softmax 操作。
希望能帮到你!如有更多问题,请继续提问。