class NormalLoss(nn.Module): def __init__(self,ignore_lb=255, *args, **kwargs): super( NormalLoss, self).__init__() self.criteria = nn.CrossEntropyLoss(ignore_index=ignore_lb, reduction='none') def forward(self, logits, labels): N, C, H, W = logits.size() loss = self.criteria(logits, labels) return torch.mean(loss) class Fusionloss(nn.Module): def __init__(self): super(Fusionloss, self).__init__() self.sobelconv=Sobelxy() def forward(self,image_vis,image_ir,labels,generate_img,i): image_y=image_vis[:,:1,:,:] x_in_max=torch.max(image_y,image_ir) loss_in=F.l1_loss(x_in_max,generate_img) y_grad=self.sobelconv(image_y) ir_grad=self.sobelconv(image_ir) generate_img_grad=self.sobelconv(generate_img) x_grad_joint=torch.max(y_grad,ir_grad) loss_grad=F.l1_loss(x_grad_joint,generate_img_grad) loss_total=loss_in+10*loss_grad return loss_total,loss_in,loss_grad
时间: 2024-04-18 20:23:51 浏览: 198
这段代码定义了两个损失函数类:NormalLoss和Floss。
NormalLoss是一个普通的损失函数类,继承自nn.Module。构造函数中包含了一个可选参数ignore_lb,默认值为255。在构造函数中,使用nn.CrossEntropyLoss作为损失函数的标准,设置ignore_index为ignore_lb,reduction为'none',这样可以得到每个样本的损失值。在前向传播方法forward中,计算logits和labels之间的交叉熵损失loss,并取平均值返回。
Fusionloss是一个融合损失函数类,继承自nn.Module。构造函数中初始化了一个Sobelxy模块(未给出代码),该模块用于计算图像的梯度。在前向传播方法forward中,接受image_vis、image_ir、labels、generate_img和i作为输入。首先从image_vis中提取灰度通道image_y,然后计算image_y和image_ir的最大值x_in_max,并使用F.l1_loss计算其与generate_img之间的L1损失loss_in。接下来,分别计算image_y、image_ir和generate_img的梯度,并取最大值得到x_grad_joint。再次使用F.l1_loss计算x_grad_joint和generate_img_grad之间的L1损失loss_grad。最后,将loss_in和10倍的loss_grad相加得到总的损失loss_total,并返回。
整体来看,这段代码定义了两个损失函数类,NormalLoss用于计算交叉熵损失,Fusionloss用于计算融合损失。