python求得交叉熵损失值代码
时间: 2023-05-16 18:06:48 浏览: 57
以下是 Python 中求交叉熵损失值的代码:
import numpy as np
def cross_entropy_loss(y_pred, y_true):
"""
计算交叉熵损失值
:param y_pred: 预测值,形状为 (batch_size, num_classes)
:param y_true: 真实值,形状为 (batch_size, num_classes)
:return: 交叉熵损失值
"""
# 防止出现 log(0) 的情况,加上一个极小值
eps = 1e-15
# 计算交叉熵损失值
loss = -np.sum(y_true * np.log(y_pred + eps))
# 返回平均损失值
return loss / y_pred.shape[0]
相关问题
中值频率平衡 交叉熵损失
### 中值频率平衡在交叉熵损失函数中的应用
对于类别不平衡的数据集,在训练分类器时,直接使用标准的交叉熵损失可能导致模型偏向多数类而忽略少数类。为此,引入加权交叉熵损失是一种有效的解决方案。
#### 加权交叉熵损失公式
加权交叉熵损失通过为不同类别的样本分配不同的权重来调整损失函数的影响程度:
\[ L_{weighted} = -\sum_{i=1}^{N}\sum_{c=1}^{C} w_c y_i^c \log(p_i^c) \]
其中 \(w_c\) 表示第 c 类的权重, \(y_i^c\) 是真实标签向量,\(p_i^c\) 则是预测概率分布[^1]。
#### 权重计算方式——中值频率平衡
一种常见的权重设定策略即基于中值频率平衡的方法。具体而言,先统计各个类别的像素数量并求得其对应的频率;接着找出这些频率值的中位数 `median_freq`;最后利用该中位数值分别去除各类别原始频率值得到最终用于加权的系数列表[^4]。
此过程可表示如下:
设某数据集中共有 C 个类别,则每种类别 j 的频率 freq(j),以及所有类别频率组成的数组 Freqs[] 可以这样定义:
\[freq(j)=\frac{num\_of\_pixels(class=j)}{\text {total number of pixels}}\]
随后获取Freqs[] 数组内的元素中位数 median_freq ,再针对每一个类别j 计算相应的权重 weight[j]:
\[weight[j]=\frac{median\_freq}{freq(j)}\]
上述操作确保了即使某些特定类型的实例较少也能获得相对较高的重视度,从而改善整体识别效果。
#### Python 实现代码
下面给出一段简单的Python代码片段展示如何实现这一机制:
```python
import numpy as np
def calculate_weights(label_counts):
"""
Calculate class weights using the median frequency balancing method.
Args:
label_counts (list or array): Number of occurrences for each class
Returns:
list: Weights corresponding to input classes
"""
total_pixels = sum(label_counts)
frequencies = [count / float(total_pixels) for count in label_counts]
median_frequency = np.median(frequencies)
# Compute and normalize weights based on inverse frequency ratio with respect to median value
weights = []
for f in frequencies:
if f != 0:
weights.append(median_frequency / f)
else:
weights.append(0.) # Avoid division by zero when a category has no samples
return weights
# Example usage
label_distribution = [8000, 2000, 500, 300, 100] # Hypothetical distribution across five categories
class_weights = calculate_weights(label_distribution)
print("Class Weights:", class_weights)
```
机器学习逻辑回归交叉熵损失
### 逻辑回归中的交叉熵损失函数
#### 背景介绍
在机器学习领域,尤其是针对二分类问题时,逻辑回归是一种广泛应用的方法。逻辑回归不仅能够给出类别预测的结果,还能提供该结果的概率估计。为了训练这样的模型并调整其权重参数,需要定义一个有效的损失函数来衡量模型预测值与真实标签之间的差异。
#### 交叉熵损失函数的作用
对于分类任务而言,相比于平方差误差(SSE),采用交叉熵作为损失函数具有明显优势[^1]。这是因为,在处理多类别的概率分布情况下,SSE可能会导致梯度消失现象,从而阻碍了有效更新权值;而交叉熵则能更好地捕捉到不同类别间的真实差距,并且有助于加速收敛速度以及提高最终性能表现。
具体来说,给定一对输入特征向量 \( \mathbf{x} \) 和对应的二元目标变量 \( y\in{0,1} \), 如果使用逻辑回归建模,则输出为:
\[ p(y=1|\mathbf{x};\theta)=h_\theta(\mathbf{x})=\frac{1}{1+\exp(-\theta^\top\mathbf{x})}, \]
其中 \( h_\theta(\cdot)\ ) 表示假设函数,\( \theta \) 是待估参数向量。此时,单样本上的交叉熵损失可表示为:
\[ L_{CE}(y,\hat{y})=-[y\log{\hat{y}}+(1-y)\log{(1-\hat{y})}], \]
这里 \( \hat{y}=h_\theta(x) \).
此表达式的直观意义在于:当实际标签接近于某个极端值 (即几乎完全属于某一类) ,那么希望模型对该实例也作出相似程度的确信判断; 反之亦然.
#### 计算过程展示
下面是一个简单的Python实现例子,展示了如何利用PyTorch框架构建带有交叉熵损失的逻辑回归模型来进行MNIST手写数字识别任务的一部分代码片段:
```python
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn.functional as F
class LogisticRegression(torch.nn.Module):
def __init__(self, input_dim, output_dim):
super(LogisticRegression, self).__init__()
self.linear = torch.nn.Linear(input_dim, output_dim)
def forward(self, x):
outputs = torch.sigmoid(self.linear(x))
return outputs
def train(model, device, train_loader, optimizer, epoch):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device).view(data.shape[0], -1), target.to(device)
optimizer.zero_grad() # 清除之前的梯度
output = model(data.float()) # 前向传播获取预测结果
loss = F.binary_cross_entropy(output.squeeze(), target.float()) # 使用内置方法计算交叉熵损失
loss.backward() # 后向传播计算梯度
optimizer.step() # 更新参数
if batch_idx % log_interval == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.item()))
```
这段程序首先导入必要的库文件,接着定义了一个继承自`torch.nn.Module` 的 `LogisticRegression` 类用来创建线性变换层加上 Sigmoid 激活后的前馈神经网络结构。之后编写了名为 `train()` 的辅助函数负责完成一轮完整的训练流程——包括加载批次数据、执行正向传递获得当前状态下的预测得分、依据这些分数调用 PyTorch 提供好的接口快速求得 CE Loss 并反传累积起来的梯度最后再一步到位地实施一次SGD 来修正所有参与运算过的张量内的数值。
阅读全文
相关推荐











