EfficientNetV2 中的 se模块代码
时间: 2023-05-31 19:02:16 浏览: 72
将SE模块与密集连接网络结合的代码
EfficientNetV2 中的 se 模块代码如下:
```
class SE(nn.Module):
def __init__(self, in_chs, reduction_ratio=4):
super(SE, self).__init__()
self.squeeze = nn.AdaptiveAvgPool2d(1)
self.excitation = nn.Sequential(
nn.Linear(in_chs, in_chs // reduction_ratio, bias=False),
nn.ReLU(inplace=True),
nn.Linear(in_chs // reduction_ratio, in_chs, bias=False),
nn.Sigmoid()
)
def forward(self, x):
b, c, _, _ = x.size()
y = self.squeeze(x).view(b, c)
y = self.excitation(y).view(b, c, 1, 1)
return x * y
```
该模块包含 squeeze 和 excitation 两个部分。squeeze 是一个全局平均池化层,对输入的特征图进行池化,输出一个 C×1×1 的张量。excitation 是一个包含两个全连接层和 ReLU 激活函数的序列,将输入的 C×1×1 张量压缩成 C//reduction_ratio×1×1 的张量,再通过 Sigmoid 函数进行激活,输出一个 C×1×1 的张量。最后将输入的特征图和输出的张量进行乘积,得到经过 se 模块处理后的特征图。
阅读全文