torch.nn.CrossEntropyLoss()函数
时间: 2024-02-27 10:15:51 浏览: 87
torch.nn.CrossEntropyLoss() 是一个用于计算交叉熵损失的函数。在深度学习中,交叉熵损失函数通常用于多分类任务。
交叉熵损失函数的输入包括两个部分:模型的输出和真实标签。它首先将模型的输出通过 softmax 函数转换为概率分布,然后计算预测标签与真实标签之间的交叉熵。
具体而言,torch.nn.CrossEntropyLoss() 函数接受两个输入参数:
- input: 模型的输出,通常是一个 (N, C) 的张量,N 表示样本数量,C 表示类别数量。每一行代表一个样本的预测概率分布。
- target: 真实标签,通常是一个 (N,) 的长整型张量,每个元素表示对应样本的真实类别。
该函数会返回一个标量值,表示整个样本的平均交叉熵损失。
使用示例:
```python
import torch
import torch.nn as nn
# 假设模型有3个类别
model_output = torch.tensor([[0.1, 0.2, 0.7], [0.3, 0.4, 0.3]])
true_labels = torch.tensor([2, 0]) # 真实类别为2和0
loss_fn = nn.CrossEntropyLoss()
loss = loss_fn(model_output, true_labels)
print(loss) # 输出损失值
```
这里的模型输出 `model_output` 是一个 (2, 3) 的张量,表示两个样本属于每个类别的概率。真实标签 `true_labels` 是一个 (2,) 的张量,表示每个样本的真实类别。函数计算出的损失值 `loss` 是一个标量值。
阅读全文