CBAM在pytorch实现
时间: 2024-03-23 14:06:51 浏览: 65
pytorch_CAM实现
在PyTorch中,CBAM(Convolutional Block Attention Module)可以通过自定义模型类来实现。下面是一个简单的CBAM模块的示例代码:
```python
import torch
import torch.nn as nn
class CBAM(nn.Module):
def __init__(self, channels, reduction=16):
super(CBAM, self).__init__()
# Channel Attention
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.max_pool = nn.AdaptiveMaxPool2d(1)
self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1, padding=0)
self.relu = nn.ReLU(inplace=True)
self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1, padding=0)
# Spatial Attention
self.conv3 = nn.Conv2d(2, 1, kernel_size=7, padding=3)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
# Channel Attention
avg_out = self.fc2(self.relu(self.fc1(self.avg_pool(x))))
max_out = self.fc2(self.relu(self.fc1(self.max_pool(x))))
channel_attention = self.sigmoid(avg_out + max_out)
# Spatial Attention
spatial_avg = torch.mean(channel_attention, dim=1, keepdim=True)
spatial_max, _ = torch.max(channel_attention, dim=1, keepdim=True)
spatial_attention = self.sigmoid(self.conv3(torch.cat([spatial_avg, spatial_max], dim=1)))
# Apply attention to the input
x = x * spatial_attention * channel_attention
return x
```
这个示例代码中定义了一个CBAM模块,包含了通道注意力和空间注意力两部分。通道注意力使用了全局平均池化和全局最大池化来提取通道信息,然后通过全连接层和ReLU激活函数来学习通道权重,最后使用Sigmoid函数对权重进行归一化。空间注意力使用了平均池化和最大池化来提取空间信息,然后通过卷积层和Sigmoid函数来生成一个注意力图。最后,将通道注意力和空间注意力相乘,得到最终的特征图。
你可以将这个CBAM模块加入到你的模型中,作为一个子模块来使用。
阅读全文