UNet 和 Attention UNet 都做成两层，请给我全部完整的代码，我直接复制粘贴就能一次运行成功。
时间: 2024-11-22 10:48:51 浏览: 27
以下是 `UNet` 和 `AttentionUNet` 模型的完整代码，包括所有必要的导入和定义：
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
class UNet(nn.Module):
    """Two-level U-Net: two 2x down-sampling stages with skip connections.

    Encoder: ``encoder1`` runs at full resolution, ``encoder2`` after one 2x
    max-pool, and ``bottleneck`` after a second 2x max-pool. Decoder: each
    stage doubles H and W with ``ConvTranspose2d``, concatenates the encoder
    output of the same resolution (the U-Net skip connection) and fuses the
    result with ``conv_block``. A 1x1 convolution maps to ``out_channels``
    and a sigmoid squashes the output into [0, 1].

    Args:
        in_channels: Channels of the input tensor (default 3).
        out_channels: Channels of the output tensor (default 3).
        base_channels: Width of level 1; level 2 and the bottleneck use 2x
            and 4x this width (default 64).

    ``forward`` maps ``(batch, in_channels, H, W)`` to
    ``(batch, out_channels, H, W)``. H and W do not need to be divisible
    by 4, because up-sampled maps are zero-padded to the skip's size.
    """

    def __init__(self, in_channels=3, out_channels=3, base_channels=64):
        super().__init__()
        c1, c2, c3 = base_channels, base_channels * 2, base_channels * 4
        # Encoder path.
        self.encoder1 = self.conv_block(in_channels, c1)
        self.encoder2 = self.conv_block(c1, c2)
        self.bottleneck = self.conv_block(c2, c3)
        # Decoder path: each conv block sees skip + up-sampled features,
        # hence twice the channels of its output.
        self.upconv2 = nn.ConvTranspose2d(c3, c2, kernel_size=2, stride=2)
        self.decoder2 = self.conv_block(c2 * 2, c2)
        self.upconv1 = nn.ConvTranspose2d(c2, c1, kernel_size=2, stride=2)
        self.decoder1 = self.conv_block(c1 * 2, c1)
        self.final_conv = nn.Conv2d(c1, out_channels, kernel_size=1, stride=1, padding=0)
        self.sigmoid = nn.Sigmoid()

    def conv_block(self, in_channels, out_channels):
        """Return two (3x3 conv -> BatchNorm -> ReLU) layers; H and W are preserved."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def _match_size(up, skip):
        """Zero-pad ``up`` on the bottom/right so its H and W equal ``skip``'s."""
        # Max-pooling floors odd sizes, so after the 2x transposed conv the
        # up-sampled map can be one pixel smaller than the skip it joins.
        dh = skip.size(2) - up.size(2)
        dw = skip.size(3) - up.size(3)
        return F.pad(up, [dw // 2, dw - dw // 2, dh // 2, dh - dh // 2])

    def forward(self, x):
        e1 = self.encoder1(x)
        e2 = self.encoder2(F.max_pool2d(e1, 2))
        b = self.bottleneck(F.max_pool2d(e2, 2))

        d2 = self._match_size(self.upconv2(b), e2)
        d2 = self.decoder2(torch.cat((e2, d2), dim=1))

        d1 = self._match_size(self.upconv1(d2), e1)
        d1 = self.decoder1(torch.cat((e1, d1), dim=1))

        return self.sigmoid(self.final_conv(d1))
class AttentionBlock(nn.Module):
    """Additive attention gate applied to a skip connection.

    The gating signal ``g`` and the skip feature ``x`` are each projected to
    ``F_int`` channels by a 1x1 convolution + BatchNorm, summed, passed
    through ReLU, and reduced to a single-channel map ``psi`` by a 1x1
    convolution + BatchNorm + sigmoid. The skip feature is then re-weighted
    pixel-wise: every channel of a pixel is multiplied by the same value in
    [0, 1].

    Args:
        F_g: Number of channels of the gating signal ``g``.
        F_l: Number of channels of the skip feature ``x``.
        F_int: Number of intermediate channels of the two projections.

    ``forward(g, x)`` returns a tensor with the same shape as ``x``. ``g``
    may have a different spatial size than ``x``. Its projection is
    bilinearly resampled onto ``x``'s grid before the sum.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int),
        )
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int),
        )
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        )
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        if g1.shape[-2:] != x1.shape[-2:]:
            # A coarser gating signal cannot be added element-wise; bring it
            # onto the skip feature's grid first.
            g1 = F.interpolate(g1, size=x1.shape[-2:], mode="bilinear", align_corners=False)
        # g1 + x1 is a fresh tensor, so the in-place ReLU does not touch g1/x1.
        psi = self.psi(self.relu(g1 + x1))
        # psi is (batch, 1, H, W) and broadcasts over x's channels.
        return x * psi
class AttentionUNet(nn.Module):
    """Two-level Attention U-Net.

    Encoder: ``encoder1`` runs at full resolution, ``encoder2`` after one 2x
    max-pool, and ``bottleneck`` after a second 2x max-pool. Decoder: each
    stage up-samples with ``ConvTranspose2d``. It then gates the same-resolution
    encoder output with an ``AttentionBlock``, using the up-sampled feature as
    the gating signal. Finally it concatenates the gated skip with the up-sampled
    feature and fuses them with ``conv_block``. A 1x1 convolution and a sigmoid
    produce the output in [0, 1].

    Args:
        in_channels: Channels of the input tensor (default 3).
        out_channels: Channels of the output tensor (default 3).
        base_channels: Width of level 1; level 2 and the bottleneck use 2x
            and 4x this width (default 64).

    ``forward`` maps ``(batch, in_channels, H, W)`` to
    ``(batch, out_channels, H, W)``. H and W do not need to be divisible
    by 4, because up-sampled maps are zero-padded to the skip's size
    before gating and concatenation.
    """

    def __init__(self, in_channels=3, out_channels=3, base_channels=64):
        super().__init__()
        c1, c2, c3 = base_channels, base_channels * 2, base_channels * 4
        # Encoder path.
        self.encoder1 = self.conv_block(in_channels, c1)
        self.encoder2 = self.conv_block(c1, c2)
        self.bottleneck = self.conv_block(c2, c3)
        # Decoder path; each attention gate uses half the skip width as F_int.
        self.upconv2 = nn.ConvTranspose2d(c3, c2, kernel_size=2, stride=2)
        self.att2 = AttentionBlock(F_g=c2, F_l=c2, F_int=max(1, c2 // 2))
        self.decoder2 = self.conv_block(c2 * 2, c2)
        self.upconv1 = nn.ConvTranspose2d(c2, c1, kernel_size=2, stride=2)
        self.att1 = AttentionBlock(F_g=c1, F_l=c1, F_int=max(1, c1 // 2))
        self.decoder1 = self.conv_block(c1 * 2, c1)
        self.final_conv = nn.Conv2d(c1, out_channels, kernel_size=1, stride=1, padding=0)
        self.sigmoid = nn.Sigmoid()

    def conv_block(self, in_channels, out_channels):
        """Return two (3x3 conv -> BatchNorm -> ReLU) layers; H and W are preserved."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def _match_size(up, skip):
        """Zero-pad ``up`` on the bottom/right so its H and W equal ``skip``'s."""
        # Max-pooling floors odd sizes, so after the 2x transposed conv the
        # up-sampled map can be one pixel smaller than the skip it joins.
        dh = skip.size(2) - up.size(2)
        dw = skip.size(3) - up.size(3)
        return F.pad(up, [dw // 2, dw - dw // 2, dh // 2, dh - dh // 2])

    def forward(self, x):
        e1 = self.encoder1(x)
        e2 = self.encoder2(F.max_pool2d(e1, 2))
        b = self.bottleneck(F.max_pool2d(e2, 2))

        d2 = self._match_size(self.upconv2(b), e2)
        s2 = self.att2(g=d2, x=e2)
        d2 = self.decoder2(torch.cat((s2, d2), dim=1))

        d1 = self._match_size(self.upconv1(d2), e1)
        s1 = self.att1(g=d1, x=e1)
        d1 = self.decoder1(torch.cat((s1, d1), dim=1))

        return self.sigmoid(self.final_conv(d1))
```
### 使用说明
1. **导入依赖**：确保已安装 PyTorch（例如 `pip install torch`）。这段代码只依赖 `torch`，不需要其他库。
2. **实例化模型**:你可以通过以下方式实例化 `UNet` 或 `AttentionUNet` 模型:
```python
# Build both networks with their default configuration.
unet, attention_unet = UNet(), AttentionUNet()
```
3. **前向传播**：输入张量的形状应为 `(batch_size, 3, height, width)`。建议 `height` 和 `width` 取 16 的倍数（如 256），以保证下采样后再上采样时特征图尺寸能对齐。输出经过 Sigmoid，取值在 0 到 1 之间。例如：
```python
# Dummy batch: one 3-channel 256x256 image of random values, for a shape check.
input_tensor = torch.randn(1, 3, 256, 256)
output = unet(input_tensor)
attention_output = attention_unet(input_tensor)
# Both models should return a tensor with the same shape as the input.
print(output.shape, attention_output.shape)
```
希望这些代码对你有帮助!如果有任何问题,请随时提问。
阅读全文