class BCPLoss(torch.nn.Module): def __init__(self, eps=1e-6): super(BCPLoss, self).__init__() self.eps = eps def forward(self, x, y): # Calculate BCP feature bcp_x = self._bcp(x) bcp_y = self._bcp(y) # Calculate BCP Loss loss = torch.mean((bcp_x - bcp_y) ** 2) loss = torch.clamp(loss, min=self.eps) return loss def _bcp(self, x): # Calculate bright channel feature bcp = torch.mean(torch.abs(x), dim=(2, 3), keepdim=True) bcp = F.relu(bcp) return bcp
时间: 2024-02-02 11:03:32 浏览: 174
torch_sparse-0.6.15+pt112cu116-cp310-cp310-win_amd64.whl.zip
这段代码是用来实现 Bright Channel Prior Loss (BCP Loss) 的,是一种图像去雾的损失函数。该损失函数的计算过程分为两步,第一步是计算输入图片 x 和目标图片 y 的亮度通道,即 Bright Channel Feature。第二步是计算两者之间的差异,即 BCP Loss。其中,eps 表示一个极小的常数,用来防止除数为0的情况发生。
阅读全文