Spatial Reduction Attention Block (SRAB) Code
Posted: 2024-03-01 17:51:58
Below is a PyTorch implementation of a Spatial Reduction Attention Block (SRAB), provided for reference only:
```python
import torch
import torch.nn as nn


class ChannelAttention(nn.Module):
    """Channel attention: global average pooling followed by a bottleneck of 1x1 convs."""

    def __init__(self, in_channels, reduction_ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(in_channels, in_channels // reduction_ratio, 1, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(in_channels // reduction_ratio, in_channels, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Squeeze spatial dims to 1x1, then re-weight channels through the bottleneck
        y = self.avg_pool(x)
        y = self.fc1(y)
        y = self.relu(y)
        y = self.fc2(y)
        y = self.sigmoid(y)
        return y


class SpatialAttention(nn.Module):
    """Spatial attention: channel-wise mean/max maps fused by a single conv."""

    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Pool along the channel dimension to get two single-channel spatial descriptors
        avg_mask = torch.mean(x, dim=1, keepdim=True)
        max_mask, _ = torch.max(x, dim=1, keepdim=True)
        mask = torch.cat([avg_mask, max_mask], dim=1)
        mask = self.conv1(mask)
        mask = self.sigmoid(mask)
        return mask


class SRAB(nn.Module):
    """Channel attention, then spatial attention, with a residual connection."""

    def __init__(self, in_channels, reduction_ratio=16, kernel_size=7):
        super(SRAB, self).__init__()
        self.ca = ChannelAttention(in_channels, reduction_ratio)
        self.sa = SpatialAttention(kernel_size)

    def forward(self, x):
        y = x * self.ca(x)  # channel-attention weighting
        y = y * self.sa(y)  # spatial-attention weighting
        y = y + x           # residual connection back to the input
        return y
```
Here, `ChannelAttention` and `SpatialAttention` implement the channel-attention and spatial-attention computations respectively, and `SRAB` assembles them into the full module. In the forward pass, the input is first weighted by channel attention, the result is then weighted by spatial attention, and finally a residual connection adds the original input back to produce the output feature map.
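The forward pass above can be sketched functionally to verify that the block preserves the input shape (as the residual connection requires). This is a minimal, self-contained rewrite of the same computation with freshly initialized weights, not the module's trained behavior; the default `reduction_ratio=16` and `kernel_size=7` from the listing are assumed:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def srab_forward(x, reduction_ratio=16, kernel_size=7):
    """Functional sketch of the SRAB forward pass with random (untrained) weights."""
    b, c, h, w = x.shape
    # Channel attention: global average pool -> 1x1 conv bottleneck -> sigmoid
    fc1 = nn.Conv2d(c, c // reduction_ratio, 1, bias=False)
    fc2 = nn.Conv2d(c // reduction_ratio, c, 1, bias=False)
    ca = torch.sigmoid(fc2(F.relu(fc1(F.adaptive_avg_pool2d(x, 1)))))
    y = x * ca
    # Spatial attention: channel-wise mean/max -> KxK conv -> sigmoid
    avg_mask = y.mean(dim=1, keepdim=True)
    max_mask, _ = y.max(dim=1, keepdim=True)
    conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
    sa = torch.sigmoid(conv(torch.cat([avg_mask, max_mask], dim=1)))
    # Residual connection: output has the same shape as the input
    return y * sa + x

x = torch.randn(2, 64, 32, 32)
out = srab_forward(x)
print(out.shape)  # torch.Size([2, 64, 32, 32])
```

Because every stage multiplies the feature map by a broadcastable attention mask, the output shape always matches the input, which is what allows the block to be dropped into an existing backbone.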