nn.CrossEntropyLoss()使用案例
时间: 2024-02-27 10:23:29 浏览: 90
Pytorch中的CrossEntropyLoss()函数案例解读和结合one-hot编码计算Loss
假设你正在解决一个分类问题,有 $n$ 个类别。你的模型输出了一个 $m$ 维的向量,表示每个类别的得分。你的目标是将这个向量转换成一个概率分布,使得每个类别的概率都在 $[0,1]$ 范围内且和为 $1$。你可以使用 softmax 函数将得分向量转换成概率分布。
在使用 softmax 函数之后,你需要将真实标签转换成 one-hot 编码的形式,这样每个标签就可以表示成一个 $n$ 维的向量,其中只有一个元素为 $1$,其余元素都为 $0$。
一旦你得到了模型预测的概率分布和真实标签的 one-hot 编码,你就可以使用交叉熵损失函数来计算模型的损失。交叉熵损失函数的定义如下:
$$
\text{CE}(\mathbf{p}, \mathbf{y}) = -\sum_{i=1}^{n} y_i \log p_i
$$
其中 $\mathbf{p}$ 是模型预测的概率分布,$\mathbf{y}$ 是真实标签的 one-hot 编码。$\log$ 表示自然对数。
在 PyTorch 中,你可以使用 nn.CrossEntropyLoss() 函数来计算交叉熵损失。这个函数的输入是模型预测的得分向量和真实标签的类别索引。具体使用方法如下:
```python
import torch.nn as nn
import torch
# 假设模型输出了一个大小为 (batch_size, n) 的张量,表示每个类别的得分
scores = torch.randn(3, 5)
# 假设真实标签为一个大小为 (batch_size,) 的张量,每个元素是一个类别索引
labels = torch.tensor([2, 0, 4])
# 创建交叉熵损失函数
loss_fn = nn.CrossEntropyLoss()
# 计算损失
loss = loss_fn(scores, labels)
print(loss)
```
输出:
```
tensor(2.2438)
```
上面的例子中,我们假设模型输出了一个大小为 (3, 5) 的张量,表示三个样本在五个类别上的得分。真实标签为一个大小为 (3,) 的张量,每个元素是一个类别索引。我们创建了一个 nn.CrossEntropyLoss() 的实例,然后将模型输出的得分张量和真实标签张量传入函数中计算损失。最终输出的结果是一个标量张量。
阅读全文