class SGABlock(nn.Module): def __init__(self, channels=32, refine=False): super(SGABlock, self).__init__() self.refine = refine if self.refine: self.bn_relu = nn.Sequential(BatchNorm3d(channels), nn.ReLU(inplace=True)) self.conv_refine = BasicConv(channels, channels, is_3d=True, kernel_size=3, padding=1, relu=False) # self.conv_refine1 = BasicConv(8, 8, is_3d=True, kernel_size=1, padding=1) else: self.bn = BatchNorm3d(channels) self.SGA=SGA() self.relu = nn.ReLU(inplace=True) def forward(self, x, g): rem = x k1, k2, k3, k4 = torch.split(g, (x.size()[1]*5, x.size()[1]*5, x.size()[1]*5, x.size()[1]*5), 1) k1 = F.normalize(k1.view(x.size()[0], x.size()[1], 5, x.size()[3], x.size()[4]), p=1, dim=2) k2 = F.normalize(k2.view(x.size()[0], x.size()[1], 5, x.size()[3], x.size()[4]), p=1, dim=2) k3 = F.normalize(k3.view(x.size()[0], x.size()[1], 5, x.size()[3], x.size()[4]), p=1, dim=2) k4 = F.normalize(k4.view(x.size()[0], x.size()[1], 5, x.size()[3], x.size()[4]), p=1, dim=2) x = self.SGA(x, k1, k2, k3, k4) if self.refine: x = self.bn_relu(x) x = self.conv_refine(x) else: x = self.bn(x) assert(x.size() == rem.size()) x += rem return self.relu(x) # return self.bn_relu(x)
时间: 2023-11-22 21:05:31 浏览: 74
finite2_refine.rar_finite2_refine_介质 静电场_有限差分 速度_电场分布_静电场计算
这是一个使用了 SGA(Spatial Group-wise Aggregation)的神经网络模块,用于图像处理任务。SGA 是一种空间注意力机制,能够对输入特征图进行加权聚合,从而更好地捕捉不同位置的上下文信息。该模块接受两个输入:x 表示输入特征图,g 表示用于计算空间注意力权重的向量。模块首先将 g 向量按通道分成四份,再对每份向量进行 L1 归一化,最后将这四个归一化后的向量作为空间注意力的输入,输出加权聚合后的特征图。如果 refine 参数为 True,则模块会在空间注意力之后再进行一些卷积操作。
阅读全文