UNet3+ PyTorch Implementation
Below is a simple PyTorch implementation of UNet 3+, for reference only:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
class DoubleConv(nn.Module):
    """(Conv2d => BatchNorm2d => ReLU) applied twice."""
    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)
class Down(nn.Module):
    """Downscale with max pooling, then apply a double convolution."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)
class Up(nn.Module):
    """Upscale, pad to match the skip feature, concatenate, then double conv."""
    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()
        if bilinear:
            # Bilinear upsampling keeps the channel count, so after concatenating
            # the skip feature (which has in_channels // 2 channels in this
            # network) the double conv sees in_channels + in_channels // 2.
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels + in_channels // 2, out_channels)
        else:
            # The transposed convolution halves the channel count, so the
            # concatenated tensor has exactly in_channels channels.
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # Input is NCHW; pad x1 so its spatial size matches the skip feature x2
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]
        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)
class AttentionBlock(nn.Module):
    """Attention gate: uses the gating signal g to weight the skip feature x."""
    def __init__(self, F_g, F_l, F_int):
        super().__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid()
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        # The gating signal may come from a coarser decoder level; resize it
        # to the skip feature's spatial size before adding.
        if g1.shape[2:] != x1.shape[2:]:
            g1 = F.interpolate(g1, size=x1.shape[2:], mode='bilinear', align_corners=True)
        psi = self.relu(g1 + x1)
        psi = self.psi(psi)  # per-pixel attention coefficients in (0, 1)
        return x * psi
class UNet3Plus(nn.Module):
    def __init__(self, in_channels=3, out_channels=1, bilinear=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.bilinear = bilinear
        # Encoder: an initial double conv at full resolution, then four downsamplings
        self.inc = DoubleConv(in_channels, 32)
        self.down1 = Down(32, 64)
        self.down2 = Down(64, 128)
        self.down3 = Down(128, 256)
        self.down4 = Down(256, 512)
        # Decoder: each stage gates its skip connection with the coarser decoder feature
        self.att4 = AttentionBlock(F_g=512, F_l=256, F_int=128)
        self.up4 = Up(512, 256, bilinear)
        self.att3 = AttentionBlock(F_g=256, F_l=128, F_int=64)
        self.up3 = Up(256, 128, bilinear)
        self.att2 = AttentionBlock(F_g=128, F_l=64, F_int=32)
        self.up2 = Up(128, 64, bilinear)
        self.att1 = AttentionBlock(F_g=64, F_l=32, F_int=16)
        self.up1 = Up(64, 32, bilinear)
        self.outc = nn.Conv2d(32, out_channels, kernel_size=1)

    def forward(self, x):
        # Encoder
        x1 = self.inc(x)      # 32 channels,  H x W
        x2 = self.down1(x1)   # 64 channels,  H/2 x W/2
        x3 = self.down2(x2)   # 128 channels, H/4 x W/4
        x4 = self.down3(x3)   # 256 channels, H/8 x W/8
        x5 = self.down4(x4)   # 512 channels, H/16 x W/16 (bottleneck)
        # Decoder with attention-gated skip connections
        x = self.up4(x5, self.att4(g=x5, x=x4))
        x = self.up3(x, self.att3(g=x, x=x3))
        x = self.up2(x, self.att2(g=x, x=x2))
        x = self.up1(x, self.att1(g=x, x=x1))
        logits = self.outc(x)
        return logits
```
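As a quick sanity check, the sketch below (assuming the classes above are defined in the same script; the batch and image sizes are arbitrary choices for illustration) runs a dummy batch through the network and confirms that the output spatial size matches the input:

```python
# Minimal smoke test: forward a dummy batch and check the output shape.
model = UNet3Plus(in_channels=3, out_channels=1, bilinear=True)
dummy = torch.randn(2, 3, 128, 128)  # batch of 2 RGB images, 128 x 128
with torch.no_grad():
    out = model(dummy)
print(out.shape)  # expected: torch.Size([2, 1, 128, 128])
```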
In this implementation, we first define a double-convolution block (DoubleConv), a downsampling block (Down), and an upsampling block (Up). We then define an attention block (AttentionBlock) that implements the attention mechanism on the skip connections, and finally combine all of these blocks to build the UNet 3+ network.
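For completeness, a single training step might look like the following minimal sketch. The placeholder data, optimizer settings, and the choice of BCEWithLogitsLoss (a reasonable fit here because the network returns raw logits with out_channels=1) are illustrative assumptions, not part of the original answer:

```python
# Illustrative training step for binary segmentation (assumptions: images of
# shape [N, 3, H, W] and binary masks of shape [N, 1, H, W]).
model = UNet3Plus(in_channels=3, out_channels=1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.BCEWithLogitsLoss()  # expects raw logits, matching outc above

images = torch.randn(4, 3, 128, 128)                    # placeholder batch
masks = torch.randint(0, 2, (4, 1, 128, 128)).float()   # placeholder masks

optimizer.zero_grad()
logits = model(images)
loss = criterion(logits, masks)
loss.backward()
optimizer.step()
print(f"loss: {loss.item():.4f}")
```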