残差模块代码实现,瓶颈残差模块代码实现,倒残差代码实现,沙漏残差代码实现
时间: 2023-05-26 13:06:13 浏览: 59
由于不清楚具体需要哪种残差模块,以下给出四种常见残差模块的 PyTorch 代码实现(使用前需先 `import torch.nn as nn` 和 `import torch.nn.functional as F`):
1. 基础残差模块:
```
class ResidualBlock(nn.Module):
    """Basic two-convolution residual block (ResNet "BasicBlock" layout).

    Main path: 3x3 conv (carrying ``stride``) -> BN -> ReLU -> 3x3 conv -> BN.
    Shortcut: identity, or a 1x1 conv + BN projection whenever the stride or
    the channel count changes. The two paths are summed and passed through ReLU.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels of the output tensor.
        stride: stride of the first 3x3 conv and of the projection shortcut.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_channels, out_channels,
            kernel_size=3, stride=stride, padding=1, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(
            out_channels, out_channels,
            kernel_size=3, stride=1, padding=1, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.stride = stride

        # The input can only be added unchanged when neither the spatial size
        # (stride) nor the channel count differs from the main path's output.
        needs_projection = in_channels != out_channels or stride != 1
        if needs_projection:
            projection = [
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            ]
        else:
            projection = []
        # An empty nn.Sequential acts as the identity mapping.
        self.shortcut = nn.Sequential(*projection)

    def forward(self, x):
        hidden = self.relu(self.bn1(self.conv1(x)))
        main = self.bn2(self.conv2(hidden))
        summed = main + self.shortcut(x)
        return self.relu(summed)
```
2. 瓶颈残差模块:
```
class BottleneckBlock(nn.Module):
    """ResNet bottleneck residual block: 1x1 reduce -> 3x3 -> 1x1 expand.

    ``out_channels`` is the *bottleneck width* used by the two inner convs.
    The block actually outputs ``out_channels * expansion`` channels.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: width of the inner 1x1 and 3x3 convolutions.
        stride: stride of the 3x3 conv and of the projection shortcut.
        expansion: channel multiplier applied by the final 1x1 conv
            (defaults to 4, the original hard-coded value).

    Raises:
        ValueError: if ``expansion`` is smaller than 1.
    """

    def __init__(self, in_channels, out_channels, stride=1, expansion=4):
        super().__init__()
        if expansion < 1:
            raise ValueError(f"expansion must be >= 1, got {expansion}")
        self.expansion = expansion
        final_channels = out_channels * expansion

        # 1x1 conv reduces the input to the bottleneck width.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        # The 3x3 conv carries the stride, so spatial downsampling happens here.
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        # 1x1 conv expands to out_channels * expansion.
        self.conv3 = nn.Conv2d(out_channels, final_channels, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(final_channels)
        self.relu = nn.ReLU(inplace=True)
        self.stride = stride

        # Identity shortcut by default. Project when the input's shape would
        # not match the main path's output shape.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != final_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, final_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(final_channels),
            )

    def forward(self, x):
        residual = x
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.relu(self.bn2(self.conv2(out)))
        # No activation before the addition; ReLU is applied after summing.
        out = self.bn3(self.conv3(out))
        out += self.shortcut(residual)
        return self.relu(out)
```
3. 沙漏残差模块:
```
class HourglassBlock(nn.Module):
    """Single-level hourglass: stride-2 downsample -> residual -> upsample -> residual.

    The input's spatial size is halved while its channel count doubles, then
    refined by a ResidualBlock. The result is brought back to ``n_channels`` at
    exactly the input's spatial size and refined by a second ResidualBlock, so
    the output has the same shape as the input (also for odd H or W).

    NOTE(review): this block has no skip branch from its input to its output
    (the classic stacked-hourglass design adds one at each resolution); the
    only shortcuts are those inside the two ResidualBlocks.

    Args:
        n_channels: number of channels of the input and of the output.
    """

    def __init__(self, n_channels):
        super().__init__()
        self.downsample = nn.Sequential(
            nn.Conv2d(n_channels, n_channels * 2, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(n_channels * 2),
            nn.ReLU(inplace=True),
        )
        self.upsample = nn.Sequential(
            nn.ConvTranspose2d(n_channels * 2, n_channels, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.BatchNorm2d(n_channels),
            nn.ReLU(inplace=True),
        )
        self.res1 = ResidualBlock(n_channels * 2, n_channels * 2)
        self.res2 = ResidualBlock(n_channels, n_channels)

    def forward(self, x):
        low = self.res1(self.downsample(x))
        # The downsample yields ceil(H/2). With its fixed output_padding=1 the
        # transposed conv would return H + 1 for odd H. Passing the input's
        # spatial size lets it pick the output_padding that restores H exactly.
        up = self.upsample[0](low, output_size=x.shape[-2:])
        up = self.upsample[1:](up)
        return self.res2(up)
```
4. 倒残差模块:
倒残差模块(Inverted Residual)出自 MobileNetV2,在轻量级网络中很常用;以下代码实现仅供参考:
```
class InvResBlock(nn.Module):
    """MobileNetV2-style inverted residual block with an identity shortcut.

    Layers, in order:
      - a 1x1 conv expands ``channels`` to ``channels * expand_ratio``;
      - a 3x3 depthwise conv (stride 1) filters each expanded channel;
      - a linear 1x1 conv projects back to ``channels``.

    Spatial size and channel count are preserved, so the input is added back
    directly. The previous version used a stride-2 transposed conv here, which
    doubled H and W and made ``out + x`` fail with a shape mismatch.

    Args:
        channels: number of channels of both the input and the output.
        expand_ratio: width multiplier of the hidden layer (6 in MobileNetV2).

    Raises:
        ValueError: if ``expand_ratio`` is smaller than 1.
    """

    def __init__(self, channels, expand_ratio=6):
        super().__init__()
        if expand_ratio < 1:
            raise ValueError(f"expand_ratio must be >= 1, got {expand_ratio}")
        hidden = channels * expand_ratio
        # Expansion: 1x1 conv to the wider hidden representation.
        self.conv1 = nn.Conv2d(channels, hidden, kernel_size=1, bias=False)
        self.norm1 = nn.BatchNorm2d(hidden)
        # Depthwise 3x3 (groups == channels). stride=1 and padding=1 keep H and W,
        # which the residual addition in forward() requires.
        self.conv2 = nn.Conv2d(hidden, hidden, kernel_size=3, stride=1, padding=1, groups=hidden, bias=False)
        self.norm2 = nn.BatchNorm2d(hidden)
        # Linear bottleneck: 1x1 projection with no activation afterwards.
        self.conv3 = nn.Conv2d(hidden, channels, kernel_size=1, bias=False)
        self.norm3 = nn.BatchNorm2d(channels)

    def forward(self, x):
        out = F.relu6(self.norm1(self.conv1(x)))
        out = F.relu6(self.norm2(self.conv2(out)))
        out = self.norm3(self.conv3(out))
        return out + x
```