交叉熵损失 weight
时间: 2025-01-05 08:33:38 浏览: 6
### 交叉熵损失函数中的 `weight` 参数
在处理不平衡数据集时,`weight` 参数可以用来调整不同类别之间的权重。通过赋予较少样本类更高的权重,可以使模型更加关注这些类别的预测准确性。
#### PyTorch 中的 `weight` 参数使用方法
PyTorch 的 `CrossEntropyLoss` 和 `NLLLoss` 都支持 `weight` 参数。该参数接受一个一维张量,其长度等于类别的数量。每个元素对应相应类别的权重值[^1]。
```python
import torch
import torch.nn as nn
# 定义三个类别的权重
weights = torch.tensor([0.1, 0.3, 0.6])
# 创建带权重的交叉熵损失函数实例
loss_fn = nn.CrossEntropyLoss(weight=weights)
# 假设输入 logits 和目标标签如下
logits = torch.randn(3, 3) # batch_size=3, num_classes=3
labels = torch.tensor([0, 1, 2]) # 类别索引
# 计算加权后的损失
loss = loss_fn(logits, labels)
print(loss.item())
```
当设置不同的权重时,会对最终计算出来的损失产生影响。具体来说:
- 权重较大的类别,在总损失中所占比例会更大;
- 对于少数类分配较大权重有助于缓解类别不均衡带来的问题。
#### Keras/TensorFlow 中的 `class_weight` 参数
对于 TensorFlow/Keras 用户而言,则可以通过 `fit()` 方法中的 `class_weight` 参数来指定每种类别的权重[^2]。
```python
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from sklearn.utils.class_weight import compute_class_weight
model = Sequential([
Dense(64, activation='relu', input_shape=(input_dim,)),
Dense(num_classes, activation='softmax')
])
# 自动计算类别权重
y_train_integers = np.argmax(y_train, axis=1)
class_weights = compute_class_weight('balanced',
classes=np.unique(y_train_integers),
y=y_train_integers)
# 将 numpy 数组转换成字典形式
class_weight_dict = dict(enumerate(class_weights))
# 编译并训练模型
model.compile(optimizer='adam',
loss='categorical_crossentropy',
metrics=['accuracy'])
history = model.fit(X_train,
y_train,
epochs=num_epochs,
class_weight=class_weight_dict)
```
阅读全文