空间注意力模块的实现代码
时间: 2023-10-26 14:14:42 浏览: 40
空间注意力模块是一个常用的自注意力机制,可以用于图像分类、目标检测、语义分割等任务中。下面是一个简单的空间注意力模块的实现代码,供你参考:
```python
import torch
import torch.nn as nn
class SpatialAttention(nn.Module):
def __init__(self, in_channels, out_channels):
super(SpatialAttention, self).__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.conv3 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.conv4 = nn.Conv2d(out_channels, 1, kernel_size=1, stride=1, padding=0)
self.relu = nn.ReLU(inplace=True)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
x1 = self.conv1(x)
x2 = self.conv2(x1)
x3 = self.conv3(x2)
x4 = self.conv4(x3)
x_att = self.sigmoid(x4)
out = x_att * x1
return out
```
这个空间注意力模块包含了四个卷积层和两个激活函数。其中,第一个卷积层和最后一个卷积层分别用于将输入的特征图的通道数转换为指定的输出通道数和将两个特征图相乘得到最终的输出。中间的两个卷积层用于提取输入特征图的空间信息,第二个卷积层采用3x3的卷积核,第三个卷积层同样采用3x3的卷积核。最后,使用 Sigmoid 函数进行归一化,得到输出的注意力矩阵,然后将其与输入的特征图相乘得到最终的输出。