PyTorch 搭建带注意力机制的 U-Net 分割网络
时间: 2023-05-26 14:04:12 浏览: 287
1. 首先导入所需的PyTorch模块
```
import torch
import torch.nn as nn
import torch.nn.functional as F
```
2. 定义注意力模块
该注意力模块用 1×1 卷积瓶颈(输入通道 → 门控通道 → 输入通道)对特征图做逐通道修正,并以一个可学习的残差系数 gamma(初始为 0)叠加回原特征,从而在不破坏初始行为的前提下细化解码器各层的特征表示。
```
class AttentionBlock(nn.Module):
    """Channel-refinement attention block with an optional residual path.

    Applies a 1x1-conv bottleneck ``in_channels -> gate_channels ->
    in_channels`` (each conv followed by BatchNorm + ReLU). When
    ``use_res`` is True the result is added back to the input scaled by a
    learnable scalar ``gamma`` that starts at zero, so the block behaves
    as an identity at initialisation and learns how much refinement to
    apply during training.

    Args:
        in_channels: number of channels of the input feature map.
        gate_channels: bottleneck width of the 1x1 convolutions.
        use_res: if True, output is ``x + gamma * W(x)``; otherwise ``W(x)``.

    Raises:
        ValueError: if the input's channel dimension is not ``in_channels``.
    """

    def __init__(self, in_channels, gate_channels, use_res=True):
        super(AttentionBlock, self).__init__()
        self.use_res = use_res
        self.in_channels = in_channels
        self.W = nn.Sequential(
            nn.Conv2d(in_channels, gate_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(gate_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(gate_channels, in_channels, kernel_size=1, bias=False),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
        )
        # Zero-initialised so the residual branch contributes nothing at
        # the start of training (block == identity when use_res=True).
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        # BUGFIX: was `assert x.size()[1] == self.in_channels`, which is
        # silently stripped under `python -O`; validate explicitly instead.
        if x.size(1) != self.in_channels:
            raise ValueError(
                f"expected {self.in_channels} input channels, got {x.size(1)}"
            )
        Wx = self.W(x)
        return x + self.gamma * Wx if self.use_res else Wx
```
3. 定义U-Net网络结构
```
class UNet(nn.Module):
    """5-level U-Net with attention-refined decoder stages.

    Encoder halves the spatial resolution four times (each of
    ``enc2..enc5`` begins with a MaxPool2d); the decoder upsamples four
    times with transposed convolutions, concatenating the matching
    encoder feature map at each level, and each fused decoder output is
    refined by an ``AttentionBlock``. Output has the same spatial size
    as the input.

    Args:
        in_channels: channels of the input image (default 3).
        out_channels: channels of the output map (default 1).
    """

    def __init__(self, in_channels=3, out_channels=1):
        super().__init__()
        # ---- Encoder: two 3x3 conv/BN/ReLU per level; enc2..enc5 start
        # with a 2x downsample.
        self.enc1 = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )
        self.enc2 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
        )
        self.enc3 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
        )
        self.enc4 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
        )
        self.enc5 = nn.Sequential(
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Conv2d(512, 1024, kernel_size=3, padding=1),
            nn.BatchNorm2d(1024),
            nn.ReLU(inplace=True),
            nn.Conv2d(1024, 1024, kernel_size=3, padding=1),
            nn.BatchNorm2d(1024),
            nn.ReLU(inplace=True),
        )
        # ---- Decoder: upsample 2x, then two 3x3 conv/BN/ReLU. Each of
        # dec4..dec2 receives the concatenation of a skip connection and
        # the previous decoder output, hence the doubled input channels.
        self.dec5 = nn.Sequential(
            nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
        )
        self.dec4 = nn.Sequential(
            nn.ConvTranspose2d(1024, 256, kernel_size=2, stride=2),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
        )
        self.dec3 = nn.Sequential(
            nn.ConvTranspose2d(512, 128, kernel_size=2, stride=2),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
        )
        self.dec2 = nn.Sequential(
            nn.ConvTranspose2d(256, 64, kernel_size=2, stride=2),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )
        # BUGFIX: dec1 originally began with ConvTranspose2d(128, 64,
        # stride=2). Four encoder pools need only four decoder upsamples
        # (dec5..dec2); a fifth one made the output 2x the input size.
        # A plain 3x3 conv fuses the 128 concatenated channels at full
        # resolution instead.
        self.dec1 = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )
        # Attention blocks refining the fused decoder features; the
        # channel counts match the post-concatenation tensors below.
        self.att1 = AttentionBlock(64, 64)
        self.att2 = AttentionBlock(128, 64)
        self.att3 = AttentionBlock(256, 64)
        self.att4 = AttentionBlock(512, 64)
        # Final 1x1 projection to the requested number of output channels.
        self.out = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x):
        """Run the network; H and W of ``x`` must be divisible by 16.

        Returns a tensor of shape (N, out_channels, H, W).
        """
        # BUGFIX: the original wrapped enc2..enc5 inputs in an extra
        # F.max_pool2d even though each Sequential already starts with
        # MaxPool2d, downsampling twice per level and breaking the
        # skip-connection shapes used by torch.cat below.
        enc1 = self.enc1(x)
        enc2 = self.enc2(enc1)
        enc3 = self.enc3(enc2)
        enc4 = self.enc4(enc3)
        enc5 = self.enc5(enc4)
        # Decoder: upsample, fuse with the skip connection, refine.
        dec5 = torch.cat((enc4, self.dec5(enc5)), dim=1)          # 1024 ch
        dec4 = self.att4(torch.cat((enc3, self.dec4(dec5)), dim=1))  # 512 ch
        dec3 = self.att3(torch.cat((enc2, self.dec3(dec4)), dim=1))  # 256 ch
        dec2 = self.att2(torch.cat((enc1, self.dec2(dec3)), dim=1))  # 128 ch
        dec1 = self.att1(self.dec1(dec2))                             # 64 ch
        return self.out(dec1)
```
4. 实例化模型并开始训练
可以使用常规的训练和测试代码来训练和测试新的U-Net网络结构,以便于检测和分割各种目标。
阅读全文
相关推荐

















