fn.cross_entropy
时间: 2024-12-31 18:44:37 浏览: 6
### PyTorch 中使用 `cross_entropy` 损失函数
在 PyTorch 中,`torch.nn.CrossEntropyLoss` 或者其功能版本 `torch.nn.functional.cross_entropy` 是用于多类分类问题的标准损失函数之一。此函数结合了 `log_softmax` 和 `NLLLoss` (负对数似然损失),使得实现更加高效。
#### 使用方式
为了应用该损失函数,输入张量应当具有形状 `(minibatch, C)` ,其中 `C` 表示类别数量;目标张量则应为一维张量,长度等于 mini-batch 的大小,每个元素代表对应样本的真实标签索引[^4]。
下面是一个简单的例子来展示如何创建并使用 `CrossEntropyLoss`:
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
# 假设我们有三个类别,并构建一些随机数据作为演示用途
data = torch.randn((100, 3)) # 输入特征
labels = torch.randint(0, 3, (100,)).long() # 对应的标签
dataset = TensorDataset(data, labels)
dataloader = DataLoader(dataset)
model = nn.Linear(3, 3) # 创建一个线性层模拟网络结构
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)
for batch_data, batch_labels in dataloader:
optimizer.zero_grad()
outputs = model(batch_data)
loss = criterion(outputs, batch_labels)
loss.backward()
optimizer.step()
```
在这个实例里,定义了一个三分类的任务,并且展示了怎样利用 `nn.CrossEntropyLoss()` 来计算预测值与真实标签之间的差异。注意这里并没有显式地调用 `softmax` 函数,因为 `CrossEntropyLoss` 已经内部处理好了这一点。
阅读全文