nn.CrossEntropyLoss()代码举例并注释
时间: 2023-05-11 17:04:20 浏览: 117
当使用神经网络进行分类任务时,我们需要定义一个损失函数来衡量模型预测结果与真实标签之间的差异。在PyTorch中,nn.CrossEntropyLoss()是一个常用的损失函数,它结合了softmax函数和负对数似然损失函数。
下面是一个nn.CrossEntropyLoss()的代码示例:
```python
import torch.nn as nn
# 定义一个分类任务的神经网络模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.fc1 = nn.Linear(10, 5)
self.fc2 = nn.Linear(5, 2)
def forward(self, x):
x = self.fc1(x)
x = nn.functional.relu(x)
x = self.fc2(x)
return x
# 定义损失函数
criterion = nn.CrossEntropyLoss()
# 假设我们有一个batch_size为3的数据集,每个样本有10个特征和一个标签
inputs = torch.randn(3, 10)
labels = torch.tensor([0, 1, 0])
# 前向传播
net = Net()
outputs = net(inputs)
# 计算损失
loss = criterion(outputs, labels)
# 打印损失值
print(loss.item())
```
在上面的代码中,我们首先定义了一个分类任务的神经网络模型,然后使用nn.CrossEntropyLoss()定义了损失函数。接着,我们生成了一个batch_size为3的数据集,每个样本有10个特征和一个标签。在前向传播过程中,我们将数据集输入到模型中,得到模型的输出。最后,我们使用损失函数计算模型预测结果与真实标签之间的差异,并打印出损失值。
注释:nn.CrossEntropyLoss()的输入参数包括模型的输出和真实标签,输出是一个二维张量,第一维表示batch_size,第二维表示每个类别的得分。真实标签是一个一维张量,每个元素表示对应样本的真实类别。损失函数的输出是一个标量,表示整个batch的平均损失值。
阅读全文