transunet引入CBAM
时间: 2025-01-02 21:31:10 浏览: 11
### 如何在 TransUNet 中集成 CBAM 模块
为了提升 TransUNet 的性能,可以在模型的关键层中引入 Convolutional Block Attention Module (CBAM),这是一种轻量级的通用注意力机制。具体来说,CBAM 可以帮助改进特征表示的质量,特别是在处理复杂医学影像数据时。
#### 1. CBAM 工作原理概述
CBAM 是一种简单而有效的注意力模块,它能够在不显著增加计算成本的情况下改善卷积神经网络的表现。该模块通过两个独立的操作来生成最终的注意力图:
- **通道注意力**:通过对输入特征图的不同通道进行加权,突出重要的通道信息。
- **空间注意力**:聚焦于特征图的空间位置,增强重要区域的信息[^2]。
```python
import torch.nn as nn
class ChannelAttention(nn.Module):
def __init__(self, in_planes, ratio=16):
super(ChannelAttention, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.max_pool = nn.AdaptiveMaxPool2d(1)
self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
self.relu1 = nn.ReLU()
self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
out = avg_out + max_out
return self.sigmoid(out)
class SpatialAttention(nn.Module):
def __init__(self, kernel_size=7):
super(SpatialAttention, self).__init__()
assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
padding = 3 if kernel_size == 7 else 1
self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
avg_out = torch.mean(x, dim=1, keepdim=True)
max_out, _ = torch.max(x, dim=1, keepdim=True)
x = torch.cat([avg_out, max_out], dim=1)
x = self.conv1(x)
return self.sigmoid(x)
class CBAM(nn.Module):
def __init__(self, gate_channels, reduction_ratio=16, no_spatial=False):
super(CBAM, self).__init__()
self.channel_attention = ChannelAttention(gate_channels, reduction_ratio)
self.spatial_attention = None if no_spatial else SpatialAttention()
def forward(self, x):
x_out = self.channel_attention(x) * x
if self.spatial_attention is not None:
x_out = self.spatial_attention(x_out) * x_out
return x_out
```
#### 2. 将 CBAM 集成到 TransUNet
要在 TransUNet 架构中加入 CBAM,可以选择将其放置在网络中的特定层次上,比如编码器的最后一层之后或解码器的第一层之前。这样可以确保经过初步变换后的特征能够得到更精细的关注度调整。
```python
from monai.networks.nets import TransUNet
def add_cbam_to_transunet(transunet_model, cbam_position='decoder_start'):
"""
向现有的 TransUNet 模型添加 CBAM 层
参数:
transunet_model: 原始的 TransUNet 实例.
cbam_position: 插入 CBAM 的位置 ('encoder_end', 'decoder_start').
返回:
修改后的 TransUNet 模型实例.
"""
class TransUNetWithCBAM(TransUNet):
def __init__(self, original_model, position):
super().__init__(
res_block=original_model.res_block,
img_dim=original_model.img_dim,
in_channels=original_model.in_channels,
out_channels=original_model.out_channels,
feature_size=original_model.feature_size,
num_layers=original_model.num_layers,
hidden_size=original_model.hidden_size,
mlp_dim=original_model.mlp_dim,
num_heads=original_model.num_heads,
pos_embed=original_model.pos_embed,
norm_name=original_model.norm_name,
conv_block=original_model.conv_block,
dropout_rate=original_model.dropout_rate,
)
# 添加 CBAM 到指定的位置
self.cbam_module = CBAM(
gate_channels=self.encoder[-1].out_channels if position=='encoder_end' else \
self.decoder[0][0].in_channels
)
self.position = position
def forward(self, x):
encoder_outputs = []
for layer in self.encoder[:-1]:
x = layer(x)
encoder_outputs.append(x)
if self.position == 'encoder_end':
x = self.encoder[-1](x)
x = self.cbam_module(x)
elif self.position == 'decoder_start':
x = self.encoder[-1](x)
decoder_input = [x] + list(reversed(encoder_outputs))
x = self.cbam_module(decoder_input.pop())
for i, block in enumerate(self.decoder):
skip_connection = decoder_input[i] if i < len(decoder_input) else None
if isinstance(block, tuple): # 处理可能存在的跳跃连接
x = block[0](torch.cat((skip_connection, x), dim=1)) if skip_connection is not None else block[0](x)
else:
x = block(x)
output = self.final_conv(x)
return output
modified_model = TransUNetWithCBAM(original_model=transunet_model, position=cbam_position)
return modified_model
```
此代码片段展示了如何创建一个新的 `TransUNet` 子类,并在其内部定义了一个额外的方法用于插入 CBAM 模块。这使得可以根据需求灵活选择是在编码阶段结束还是解码阶段开始处应用 CBAM。
阅读全文