```
class ASPP(nn.Module):
    def __init__(self, dim_in, dim_out, rate=1, bn_mom=0.1):
        super(ASPP, self).__init__()
        self.branch1 = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, 1, 1, padding=0, dilation=rate, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )
        self.branch2 = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, 3, 1, padding=4 * rate, dilation=4 * rate, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )
        self.branch3 = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, 3, 1, padding=8 * rate, dilation=8 * rate, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )
        self.branch4 = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, 3, 1, padding=12 * rate, dilation=12 * rate, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )
        self.branch5 = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, 3, 1, padding=16 * rate, dilation=16 * rate, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )
        self.branch6 = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, 3, 1, padding=20 * rate, dilation=20 * rate, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True)
        )
        self.branch7 = nn.Sequential(
            nn.Conv2d(dim_in, dim_out, 3, 1, padding=24 * rate, dilation=24 * rate, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True)
        )
        self.branch8_conv = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=True)
        self.branch8_bn = nn.BatchNorm2d(dim_out, momentum=bn_mom)
        self.branch8_relu = nn.ReLU(inplace=True)
        self.conv_cat = nn.Sequential(
            nn.Conv2d(dim_out * 8, dim_out, 1, 1, padding=0, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        [b, c, row, col] = x.size()
        conv1x1 = self.branch1(x)
        conv3x3_1 = self.branch2(x)
        conv3x3_2 = self.branch3(x)
        conv3x3_3 = self.branch4(x)
        conv3x3_4 = self.branch5(x)
        conv3x3_5 = self.branch6(x)
        conv3x3_6 = self.branch7(x)
        global_feature = torch.mean(x, 2, True)
        global_feature = torch.mean(global_feature, 3, True)
        global_feature = self.branch8_conv(global_feature)
        global_feature = self.branch8_bn(global_feature)
        global_feature = self.branch8_relu(global_feature)
        global_feature = F.interpolate(global_feature, (row, col), None, 'bilinear', True)
        feature_cat = torch.cat([conv1x1, conv3x3_1, conv3x3_2, conv3x3_3, conv3x3_4, conv3x3_5, conv3x3_6, global_feature], dim=1)
        result = self.conv_cat(feature_cat)
        return result
```
用1×3卷积和3×1卷积代替这段代码的3×3卷积
时间: 2023-09-16 16:14:29 浏览: 126
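可以将原来的 3×3 卷积分解为 1×3 卷积和 3×1 卷积的串联,以代替代码中的 3×3 卷积。注意:这种分解只是对 3×3 卷积的近似,并不与原卷积严格等价;此外,下面的代码还在每个分支前增加了一个 1×1 卷积,并把各分支的膨胀率倍数由原来的 4、8、12、16、20、24 改成了 1~6。修改后的代码如下: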
```
class ASPP(nn.Module):
    """Multi-branch dilated-convolution block with factorized 3x3 kernels.

    Eight branches are computed in parallel from the same input and fused:

    * ``branch1``: 1x1 conv -> BN -> ReLU.
    * ``branch2`` .. ``branch7``: 1x1 conv -> BN -> ReLU, then a dilated
      1x3 conv and a dilated 3x1 conv (each followed by BN -> ReLU). The
      dilations are ``1 * rate`` .. ``6 * rate`` respectively.
    * ``branch8_*``: mean over the two spatial axes, 1x1 conv -> BN -> ReLU,
      bilinearly resized back to the input's spatial size.

    The eight outputs are concatenated along the channel axis and projected
    to ``dim_out`` channels by ``conv_cat`` (1x1 conv -> BN -> ReLU).

    Args:
        dim_in: Number of channels of the input tensor.
        dim_out: Number of channels produced by every branch and by the
            fused output.
        rate: Base dilation; it is forwarded to ``nn.Conv2d`` as (a multiple
            of) the ``dilation`` argument, so use a positive integer.
        bn_mom: ``momentum`` passed to every ``nn.BatchNorm2d``.
    """

    def __init__(self, dim_in, dim_out, rate=1, bn_mom=0.1):
        super(ASPP, self).__init__()
        self.branch1 = nn.Sequential(
            *self._conv_bn_relu(dim_in, dim_out, 1, 0, rate, bn_mom)
        )
        # branch2..branch7 share one layout and differ only in dilation.
        self.branch2 = self._factorized_branch(dim_in, dim_out, 1 * rate, rate, bn_mom)
        self.branch3 = self._factorized_branch(dim_in, dim_out, 2 * rate, rate, bn_mom)
        self.branch4 = self._factorized_branch(dim_in, dim_out, 3 * rate, rate, bn_mom)
        self.branch5 = self._factorized_branch(dim_in, dim_out, 4 * rate, rate, bn_mom)
        self.branch6 = self._factorized_branch(dim_in, dim_out, 5 * rate, rate, bn_mom)
        self.branch7 = self._factorized_branch(dim_in, dim_out, 6 * rate, rate, bn_mom)
        # Global-pooling branch; kept as separate attributes (not a
        # Sequential) so the parameter names stay as they were.
        self.branch8_conv = nn.Conv2d(dim_in, dim_out, 1, 1, 0, bias=True)
        self.branch8_bn = nn.BatchNorm2d(dim_out, momentum=bn_mom)
        self.branch8_relu = nn.ReLU(inplace=True)
        # Fuses the 8 concatenated branch outputs back to dim_out channels.
        self.conv_cat = nn.Sequential(
            nn.Conv2d(dim_out * 8, dim_out, 1, 1, padding=0, bias=True),
            nn.BatchNorm2d(dim_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def _conv_bn_relu(c_in, c_out, kernel_size, padding, dilation, bn_mom):
        """Return the layer list [Conv2d (stride 1), BatchNorm2d, ReLU]."""
        return [
            nn.Conv2d(c_in, c_out, kernel_size, 1, padding=padding,
                      dilation=dilation, bias=True),
            nn.BatchNorm2d(c_out, momentum=bn_mom),
            nn.ReLU(inplace=True),
        ]

    @classmethod
    def _factorized_branch(cls, dim_in, dim_out, dilation, rate, bn_mom):
        """Build 1x1 -> dilated 1x3 -> dilated 3x1, each with BN + ReLU."""
        # The padding is written as an explicit (pad_h, pad_w) pair equal to
        # the dilation along the kernel's long axis, so stride-1 output keeps
        # the input's height and width. Do NOT write ``(0, k) * rate``:
        # multiplying a tuple by an int repeats it instead of scaling it and
        # yields an invalid padding for any rate other than 1.
        layers = cls._conv_bn_relu(dim_in, dim_out, 1, 0, rate, bn_mom)
        layers += cls._conv_bn_relu(
            dim_out, dim_out, (1, 3), (0, dilation), dilation, bn_mom)
        layers += cls._conv_bn_relu(
            dim_out, dim_out, (3, 1), (dilation, 0), dilation, bn_mom)
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run all branches on ``x`` and return the fused feature map.

        ``x`` is a 4-D tensor whose second axis has ``dim_in`` channels; the
        result has ``dim_out`` channels and the same last two dimensions.
        """
        row, col = x.shape[-2:]
        branch_outputs = [
            self.branch1(x),
            self.branch2(x),
            self.branch3(x),
            self.branch4(x),
            self.branch5(x),
            self.branch6(x),
            self.branch7(x),
        ]
        # Global context: average over both spatial axes, transform, then
        # stretch the 1x1 map back to the input resolution.
        global_feature = torch.mean(x, dim=2, keepdim=True)
        global_feature = torch.mean(global_feature, dim=3, keepdim=True)
        global_feature = self.branch8_conv(global_feature)
        global_feature = self.branch8_bn(global_feature)
        global_feature = self.branch8_relu(global_feature)
        global_feature = F.interpolate(
            global_feature, size=(row, col), mode='bilinear', align_corners=True)
        branch_outputs.append(global_feature)
        feature_cat = torch.cat(branch_outputs, dim=1)
        return self.conv_cat(feature_cat)
```
阅读全文