nn.CrossEntropyLoss()中weight参数的使用例子
时间: 2024-04-22 16:18:28 浏览: 123
nn.CrossEntropyLoss()是PyTorch中用于计算交叉熵损失的函数。它的weight参数用于对不同类别的样本进行加权处理,可以用来处理类别不平衡的情况。
例如,假设我们有一个二分类问题,其中类别1的样本数量远远多于类别2的样本数量。我们可以使用weight参数来平衡两个类别的损失。
```python
import torch
import torch.nn as nn
# 假设我们有一个二分类问题,共有10个样本,其中5个属于类别1,5个属于类别2
# 类别1的样本权重为1,类别2的样本权重为2
target = torch.tensor([1, 1, 1, 1, 1, 2, 2, 2, 2, 2])
weight = torch.tensor([1, 2])
# 定义交叉熵损失函数,并传入weight参数
loss_fn = nn.CrossEntropyLoss(weight=weight)
# 假设模型输出的预测结果为:
# 类别1的预测概率为 [0.2, 0.8, 0.6, 0.4, 0.3]
# 类别2的预测概率为 [0.8, 0.2, 0.4, 0.6, 0.7]
output = torch.tensor([[0.2, 0.8], [0.8, 0.2], [0.6, 0.4], [0.4, 0.6], [0.3, 0.7]])
# 计算损失
loss = loss_fn(output, target)
print(loss)
```
在上述例子中,我们通过设置weight参数为[1, 2],对类别1和类别2的样本进行了加权处理。由于类别2的样本权重为2,因此对应的损失会有更大的贡献。这样可以平衡类别不平衡问题,使得模型更加关注少数类别的样本。
阅读全文