Deep Residual U-Net
时间: 2024-01-19 10:03:22 浏览: 87
Deep Residual U-Net 是一种基于 U-Net 和 ResNet 的图像分割网络。它采用了 U-Net 的编码器-解码器结构,同时在每个解码器块中使用了 ResNet 的残差块(DR 块)来提高特征提取能力。DR 块通过引入 SE(Squeeze-and-Excitation)机制来增强编码器的全局特征提取能力,同时使用 1×1 卷积来改变特征图的维度,以确保 3×3 卷积滤波器不受前一层的影响。此外,为了避免网络过深带来的性能退化,在两组 Conv 1×1-Conv 3×3 操作之间引入了一条捷径(shortcut),使网络可以跳过可能导致性能下降的层,并将原始特征传递到更深的层。Deep Residual U-Net 在多个图像分割任务中都取得了优秀的性能。
以下是 Deep Residual U-Net 编码器-解码器结构的 PyTorch 实现示例:
```python
import torch
import torch.nn as nn
class DRBlock(nn.Module):
    """Residual block (1x1 -> 3x3 -> 1x1 convolutions) with SE channel gating.

    The main path applies three convolutions with ReLU between them, rescales
    the result per channel with a squeeze-and-excitation gate, adds the
    shortcut branch and applies a final ReLU. Spatial size is preserved.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the block.
        reduction: Channel reduction ratio of the SE bottleneck.
    """

    def __init__(self, in_channels: int, out_channels: int, reduction: int = 16):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(out_channels, out_channels, kernel_size=1)
        self.relu = nn.ReLU(inplace=True)

        # Keep at least one channel in the SE bottleneck; a plain integer
        # division yields zero channels for narrow blocks.
        squeezed = max(1, out_channels // reduction)
        self.se = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(out_channels, squeezed, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(squeezed, out_channels, kernel_size=1),
            nn.Sigmoid(),
        )

        # The residual sum needs matching channel counts. Project the input
        # with a 1x1 convolution when the block changes width; otherwise pass
        # it through unchanged (nn.Identity has no parameters, so equal-width
        # blocks keep the same state_dict keys as before).
        if in_channels == out_channels:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        """Return the gated residual output for feature map ``x``."""
        out = self.relu(self.conv1(x))
        out = self.relu(self.conv2(out))
        out = self.conv3(out)
        # Per-channel gate in (0, 1), broadcast over the spatial dimensions.
        out = self.se(out) * out
        out = out + self.shortcut(x)
        return self.relu(out)
class DRUNet(nn.Module):
    """U-Net style encoder-decoder built from ``DRBlock`` stages.

    The encoder has one plain convolutional stage followed by three DRBlock
    stages, each separated by 2x2 max pooling. A bottleneck stage sits at the
    lowest resolution, and the decoder mirrors the encoder with transposed
    convolutions and channel-wise skip concatenation. A final 1x1 convolution
    maps to ``out_channels``; no activation is applied to the output.

    Args:
        in_channels: Number of channels of the input image.
        out_channels: Number of channels of the output map.
        init_features: Width of the first stage; each deeper stage doubles it.
    """

    def __init__(self, in_channels, out_channels, init_features=32):
        super().__init__()
        features = init_features

        # Submodule names and registration order are kept stable so that
        # state_dict keys do not change.
        self.encoder1 = nn.Sequential(
            nn.Conv2d(in_channels, features, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(features, features, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder2 = self._stage(features, features * 2)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder3 = self._stage(features * 2, features * 4)
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.encoder4 = self._stage(features * 4, features * 8)
        self.pool4 = nn.MaxPool2d(kernel_size=2, stride=2)

        self.bottleneck = self._stage(features * 8, features * 16)

        # Each decoder stage receives skip + upsampled features, i.e. twice
        # the channel count it outputs.
        self.upconv4 = nn.ConvTranspose2d(features * 16, features * 8, kernel_size=2, stride=2)
        self.decoder4 = self._stage(features * 16, features * 8)
        self.upconv3 = nn.ConvTranspose2d(features * 8, features * 4, kernel_size=2, stride=2)
        self.decoder3 = self._stage(features * 8, features * 4)
        self.upconv2 = nn.ConvTranspose2d(features * 4, features * 2, kernel_size=2, stride=2)
        self.decoder2 = self._stage(features * 4, features * 2)
        self.upconv1 = nn.ConvTranspose2d(features * 2, features, kernel_size=2, stride=2)
        self.decoder1 = self._stage(features * 2, features)

        self.conv = nn.Conv2d(features, out_channels, kernel_size=1)

    @staticmethod
    def _stage(in_channels, out_channels):
        """Build one stage: two (DRBlock -> 3x3 conv -> ReLU) groups."""
        return nn.Sequential(
            DRBlock(in_channels, out_channels),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            DRBlock(out_channels, out_channels),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def _merge(skip, upsampled):
        """Concatenate a skip feature map with its upsampled counterpart.

        Max pooling floors odd sizes, so the upsampled map can be one pixel
        smaller than the skip map; it is zero-padded to match before the
        channel-wise concatenation.
        """
        pad_h = skip.size(-2) - upsampled.size(-2)
        pad_w = skip.size(-1) - upsampled.size(-1)
        if pad_h or pad_w:
            # F.pad takes (left, right, top, bottom) for the last two dims.
            upsampled = nn.functional.pad(
                upsampled,
                (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2),
            )
        return torch.cat((skip, upsampled), dim=1)

    def forward(self, x):
        """Run the network on ``x`` and return the ``out_channels`` map."""
        # Contracting path; keep each stage output for the skip connections.
        enc1 = self.encoder1(x)
        enc2 = self.encoder2(self.pool1(enc1))
        enc3 = self.encoder3(self.pool2(enc2))
        enc4 = self.encoder4(self.pool3(enc3))

        bottleneck = self.bottleneck(self.pool4(enc4))

        # Expanding path: upsample, merge with the matching skip, refine.
        dec4 = self.decoder4(self._merge(enc4, self.upconv4(bottleneck)))
        dec3 = self.decoder3(self._merge(enc3, self.upconv3(dec4)))
        dec2 = self.decoder2(self._merge(enc2, self.upconv2(dec3)))
        dec1 = self.decoder1(self._merge(enc1, self.upconv1(dec2)))
        return self.conv(dec1)
```
阅读全文