Pytorch损失函数nn.NLLLoss2d()用法说明
在PyTorch中,损失函数是衡量模型预测与真实标签之间差距的重要工具,nn.NLLLoss2d()是其中一种损失函数,主要用于二维数据,如图像处理任务。它全称为Negative Log Likelihood Loss(负对数似然损失),在分类问题中广泛应用。 nn.NLLLoss2d()的主要功能是计算每个像素的负对数似然损失,然后对整个批次的像素进行平均,以获得整体损失。在使用nn.NLLLoss2d()之前,通常需要先通过nn.Softmax()函数计算每个像素类别的概率分布,然后再应用nn.LogSoftmax()得到对数概率。 nn.NLLLoss2d()的输入参数包括两个:预测输出张量和目标标签张量。预测输出张量的形状通常是[B, C, H, W],其中B是批次大小,C是类别数,H和W是高度和宽度。目标标签张量的形状是[B, H, W],其中每个像素值对应于预测输出中的类别索引。 在例子中,x是预测输出,其形状首先被调整为[B, C, H, W],然后通过nn.Softmax(dim=1)和nn.LogSoftmax(dim=1)转换为对数概率。y是目标标签,调整为[B, H, W],其中每个像素值代表其对应的类别。nn.NLLLoss2d()计算损失时,会对每个像素的负对数似然进行平均,最后再对批次平均。 例如,如果有一个形状为[B, 3, 3]的目标标签y,那么对于每个像素,nn.NLLLoss2d()会找到预测概率向量中对应于标签的元素,计算其对数,并取负值。然后将所有像素的损失相加,除以像素总数(这里是9),得到单个批次的平均损失。 补充知识中展示了另一个例子,使用了nn.NLLLoss2d()的完整流程。创建了输入张量inputs_tensor和目标张量targets_tensor,分别表示预测和实际标签。接着,将输入张量通过nn.LogSoftmax()转换为对数概率,目标张量则保持不变。通过nn.NLLLoss2d()计算损失,并打印出来。 nn.NLLLoss2d()在处理多类别分类问题时非常有用,特别是在图像识别和分割等任务中。由于它对每个像素进行了独立处理,因此特别适合处理二维数据。在实际应用中,结合反向传播算法,可以通过优化网络权重来最小化这个损失,从而提高模型的预测准确性。