nn.crossentropyloss示例
时间: 2023-12-13 15:05:35 浏览: 96
nn.CrossEntropyLoss是一个用于多分类问题的损失函数,在PyTorch中广泛使用。它结合了softmax激活函数和负对数似然损失,用于衡量模型预测与真实标签之间的差异。
例如,如果我们有一个包含N个类别的分类问题,输入模型的输出是大小为(N,)的张量,每个元素表示该类别的预测概率。真实标签是一个大小为(N,)的张量,其中只有一个元素是1,其余元素都是0,表示真实类别。
nn.CrossEntropyLoss的计算过程如下:
1. 首先,将模型的输出张量通过softmax函数,得到每个类别的预测概率。
2. 然后,根据真实标签的索引,从预测概率张量中取出对应的预测概率。
3. 最后,将取出的预测概率通过负对数函数求取对数似然损失。
相比于手动计算softmax和负对数似然损失,nn.CrossEntropyLoss提供了更简洁和高效的实现方式。
以下是一个nn.CrossEntropyLoss的示例:
```python
import torch
import torch.nn as nn
# 定义模型输出和真实标签
outputs = torch.tensor([[0.2, 0.5, 0.3], [0.8, 0.1, 0.1]])
targets = torch.tensor([0, 1])
# 定义损失函数
loss_func = nn.CrossEntropyLoss()
# 计算损失
loss = loss_func(outputs, targets)
print(loss)
```
输出结果为:
```python
tensor(1.0646)
```
阅读全文