一维SE代码
时间: 2023-07-02 18:18:33 浏览: 74
下面是一维SE(Squeeze-and-Excitation)代码的示例,假设输入张量的形状为 `(batch_size, channels, seq_len)`:
```python
import torch.nn as nn
class SEBlock(nn.Module):
def __init__(self, channels, reduction=16):
super(SEBlock, self).__init__()
self.avgpool = nn.AdaptiveAvgPool1d(1)
self.fc1 = nn.Linear(channels, channels // reduction)
self.relu = nn.ReLU(inplace=True)
self.fc2 = nn.Linear(channels // reduction, channels)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
batch_size, channels, seq_len = x.size()
y = self.avgpool(x).view(batch_size, channels)
y = self.fc1(y)
y = self.relu(y)
y = self.fc2(y)
y = self.sigmoid(y).view(batch_size, channels, 1)
return x * y.expand_as(x)
```
在这里,我们定义了一个 `SEBlock` 类,它包含了一个全局平均池化层、两个全连接层和一个 sigmoid 激活函数。其中,`channels` 表示输入张量的通道数,`reduction` 表示 SE 模块中的通道缩减比例,默认为 16。在前向传播过程中,我们首先对输入张量进行全局平均池化,然后将结果输入到两个全连接层中进行特征转换。最后,我们使用 sigmoid 函数将转换后的特征映射到 0 到 1 的范围内,并将其扩展到输入张量的形状,与输入张量进行相乘操作,得到最终的输出结果。
阅读全文