class ASPP(nn.Module):
    """Atrous Spatial Pyramid Pooling (ASPP) module with eight parallel branches.

    Branches:
        * ``branch1``: 1x1 conv -> BN -> ReLU.
        * ``branch2``..``branch7``: 3x3 atrous conv -> BN -> ReLU with dilation
          4, 8, 12, 16, 20 and 24 times ``rate``.
        * ``branch8_*``: image-level branch (global average -> 1x1 conv -> BN
          -> ReLU), bilinearly upsampled back to the input's spatial size.

    Every branch produces ``dim_out`` channels at the input's spatial size.
    The eight outputs are concatenated along the channel axis
    (``8 * dim_out`` channels) and fused back to ``dim_out`` by ``conv_cat``.

    Args:
        dim_in: Number of channels of the input feature map.
        dim_out: Number of channels produced by each branch and by the module.
        rate: Multiplier applied to every atrous dilation and its padding.
        bn_mom: ``momentum`` passed to every ``nn.BatchNorm2d``.
    """

    def __init__(self, dim_in, dim_out, rate=1, bn_mom=0.1):
        super().__init__()
        # 1x1 branch; ``dilation`` has no effect on a 1x1 kernel.
        self.branch1 = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, 1, 1, padding=0, dilation=rate, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )
        # 3x3 atrous branches. With a 3x3 kernel at stride 1, padding == dilation
        # keeps H and W unchanged, which the torch.cat in forward() relies on.
        self.branch2 = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, 3, 1, padding=4*rate, dilation=4*rate, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )
        self.branch3 = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, 3, 1, padding=8*rate, dilation=8*rate, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )
        self.branch4 = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, 3, 1, padding=12*rate, dilation=12*rate, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )
        self.branch5 = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, 3, 1, padding=16*rate, dilation=16*rate, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )
        self.branch6 = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, 3, 1, padding=20*rate, dilation=20*rate, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )
        self.branch7 = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, 3, 1, padding=24*rate, dilation=24*rate, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )
        # Image-level branch; operates on the (N, dim_in, 1, 1) global average.
        self.branch8_conv = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=True)
        self.branch8_bn = nn.BatchNorm2d(dim_out, momentum=bn_mom)
        self.branch8_relu = nn.ReLU(inplace=True)
        # Fuses the eight concatenated branch outputs (8 * dim_out channels).
        self.conv_cat = nn.Sequential(
            nn.Conv2d(dim_out * 8, dim_out, 1, 1, padding=0, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        """Run all eight branches on ``x`` and fuse their outputs.

        Args:
            x: 4-D tensor of shape ``(N, dim_in, H, W)``.

        Returns:
            Tensor of shape ``(N, dim_out, H, W)``.

        Note:
            In training mode ``branch8_bn`` sees a single value per channel per
            sample (the map is 1x1), so ``nn.BatchNorm2d`` raises for a batch
            of size 1.
        """
        # Unpacking four sizes also rejects non-4-D input early.
        _, _, row, col = x.size()
        conv1x1 = self.branch1(x)
        conv3x3_1 = self.branch2(x)
        conv3x3_2 = self.branch3(x)
        conv3x3_3 = self.branch4(x)
        conv3x3_4 = self.branch5(x)
        conv3x3_5 = self.branch6(x)
        conv3x3_6 = self.branch7(x)

        # Image-level pooling: average over H and W, keeping a 1x1 map.
        global_feature = x.mean(dim=(2, 3), keepdim=True)
        global_feature = self.branch8_conv(global_feature)
        global_feature = self.branch8_bn(global_feature)
        global_feature = self.branch8_relu(global_feature)
        global_feature = F.interpolate(
            global_feature, size=(row, col), mode="bilinear", align_corners=True
        )

        feature_cat = torch.cat(
            [conv1x1, conv3x3_1, conv3x3_2, conv3x3_3,
             conv3x3_4, conv3x3_5, conv3x3_6, global_feature],
            dim=1,
        )
        return self.conv_cat(feature_cat)


# Question from the original post: replace the 3×3 convolutions in this code
# with depthwise separable convolutions.
时间: 2023-09-29 15:08:09 浏览: 61
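可以使用深度可分离卷积来代替这段代码中 branch2～branch7 的所有3×3空洞卷积。深度可分离卷积分两步:先用 groups=dim_in 的3×3逐通道(depthwise)卷积,输入、输出通道数均为 dim_in;再用1×1逐点(pointwise)卷积把通道数从 dim_in 映射到 dim_out。padding 和 dilation 必须与原来保持一致(都为 k*rate),这样各分支输出的空间尺寸不变,torch.cat 才能正常拼接。具体替换方式如下:
将代码中的 nn.Conv2d(dim_in, dim_out, 3, 1, padding=4*rate, dilation=4*rate, bias=True) 替换为 nn.Sequential(nn.Conv2d(dim_in, dim_in, kernel_size=3, stride=1, padding=4*rate, dilation=4*rate, groups=dim_in, bias=False), nn.Conv2d(dim_in, dim_out, kernel_size=1, bias=True))。
将代码中的 nn.Conv2d(dim_in, dim_out, 3, 1, padding=8*rate, dilation=8*rate, bias=True) 替换为 nn.Sequential(nn.Conv2d(dim_in, dim_in, kernel_size=3, stride=1, padding=8*rate, dilation=8*rate, groups=dim_in, bias=False), nn.Conv2d(dim_in, dim_out, kernel_size=1, bias=True))。
将代码中的 nn.Conv2d(dim_in, dim_out, 3, 1, padding=12*rate, dilation=12*rate, bias=True) 替换为 nn.Sequential(nn.Conv2d(dim_in, dim_in, kernel_size=3, stride=1, padding=12*rate, dilation=12*rate, groups=dim_in, bias=False), nn.Conv2d(dim_in, dim_out, kernel_size=1, bias=True))。
将代码中的 nn.Conv2d(dim_in, dim_out, 3, 1, padding=16*rate, dilation=16*rate, bias=True) 替换为 nn.Sequential(nn.Conv2d(dim_in, dim_in, kernel_size=3, stride=1, padding=16*rate, dilation=16*rate, groups=dim_in, bias=False), nn.Conv2d(dim_in, dim_out, kernel_size=1, bias=True))。
将代码中的 nn.Conv2d(dim_in, dim_out, 3, 1, padding=20*rate, dilation=20*rate, bias=True) 替换为 nn.Sequential(nn.Conv2d(dim_in, dim_in, kernel_size=3, stride=1, padding=20*rate, dilation=20*rate, groups=dim_in, bias=False), nn.Conv2d(dim_in, dim_out, kernel_size=1, bias=True))。
将代码中的 nn.Conv2d(dim_in, dim_out, 3, 1, padding=24*rate, dilation=24*rate, bias=True) 替换为 nn.Sequential(nn.Conv2d(dim_in, dim_in, kernel_size=3, stride=1, padding=24*rate, dilation=24*rate, groups=dim_in, bias=False), nn.Conv2d(dim_in, dim_out, kernel_size=1, bias=True))。
这样做可以减少模型的参数量和计算量:每个分支卷积层的权重数从 9·dim_in·dim_out 降为 9·dim_in + dim_in·dim_out。逐通道卷积的偏置是多余的,因为其后的逐点卷积带有偏置,且后面还接了 BatchNorm,所以设为 bias=False。注意:仅把原卷积改成 groups=dim_in 并不构成深度可分离卷积,因为缺少1×1逐点卷积,而且要求 dim_out 能被 dim_in 整除。另外,如果把 dilation 改成 rate 而 padding 仍是 4*rate 等,分支输出的尺寸会变大,导致 torch.cat 报错,同时也会缩小感受野。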
阅读全文