nn.CrossEntropyLoss()函数的用法
时间: 2024-02-27 18:21:55 浏览: 88
`nn.CrossEntropyLoss()` 是 PyTorch 中用于多分类问题的损失函数。它结合了 `nn.LogSoftmax()` 和 `nn.NLLLoss()`,适用于将模型输出与目标类别进行比较的场景。
`nn.CrossEntropyLoss(weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean')`
参数说明:
- `weight`:各个类别的损失权重,可以用于处理类别不平衡问题。默认为 `None`,表示所有类别的损失权重相等。
- `size_average`:已过时的参数,被 `reduce` 参数替代。
- `ignore_index`:忽略指定索引的类别,不计算其损失。默认为 `-100`。
- `reduce`:是否对每个样本的损失进行平均。如果为 `True`,则返回所有样本损失的平均值;如果为 `False`,则返回所有样本损失的总和。默认为 `None`,表示使用全局设置。
- `reduction`:指定如何处理损失。可以是 `'none'`、`'mean'` 或 `'sum'`。默认为 `'mean'`。
使用示例:
```python
import torch
import torch.nn as nn
# 创建交叉熵损失函数
loss_fn = nn.CrossEntropyLoss()
# 假设 outputs 是模型的输出,targets 是目标类别的张量
outputs = torch.randn(10, 5) # 示例中假设模型输出为 10 个样本,每个样本有 5 个类别的得分
targets = torch.randint(5, (10,)) # 示例中假设目标类别为 10 个样本,每个样本的类别在 0 到 4 之间
# 计算损失
loss = loss_fn(outputs, targets)
print(loss)
```
在上述示例中,我们首先创建了一个 `nn.CrossEntropyLoss()` 的实例 `loss_fn`。然后,我们生成了模型的输出张量 `outputs` 和目标类别张量 `targets`。最后,我们将这两个张量作为输入传递给 `loss_fn()` 函数,计算出交叉熵损失 `loss`。
注意:`nn.CrossEntropyLoss()` 函数会自动应用 `nn.LogSoftmax()` 和 `nn.NLLLoss()`,因此在模型的输出中不需要进行 softmax 操作。同时,目标类别张量 `targets` 应当是一维的,每个元素代表一个样本的类别索引。
希望这个示例能够帮助您理解 `nn.CrossEntropyLoss()` 函数的用法。
阅读全文