drsn pytorch
时间: 2023-09-28 08:04:02 浏览: 163
DRSN是一种使用了残差收缩层的ResNet变体,在PyTorch中有相应的实现代码。残差收缩层的结构包括两次卷积操作后取绝对值,然后经过全局平均池化(GAP)将通道数缩减为1,再经过两次全连接层得到z,再经过Sigmoid函数得到一个值。最后,将GAP的结果取平均后与z相乘,得到软阈值化的结果,再与原输入进行阈值化。在PyTorch中的代码实现如下所示:
```python
class Shrinkage(nn.Module):
def __init__(self, gap_size, channel):
super(Shrinkage, self).__init__()
self.gap = nn.AdaptiveAvgPool2d(gap_size)
self.fc = nn.Sequential(
nn.Linear(channel, channel),
nn.BatchNorm1d(channel),
nn.ReLU(inplace=True),
nn.Linear(channel, 1),
nn.Sigmoid(),
)
def forward(self, x):
x_raw = x
x = torch.abs(x)
x_abs = x
x = self.gap(x)
x = torch.flatten(x, 1)
average = torch.mean(x, dim=1, keepdim=True)
x = self.fc(x)
x = torch.mul(average, x)
x = x.unsqueeze(2).unsqueeze(2)
sub = x_abs - x
zeros = sub - sub
n_sub = torch.max(sub, zeros)
x = torch.mul(torch.sign(x_raw), n_sub)
return x
```
这段代码是对DRSN中残差收缩层的PyTorch实现。其中,输入x经过取绝对值操作后,通过全局平均池化(GAP)得到一个向量,再经过两次全连接层得到z,再通过Sigmoid函数得到一个值。然后,将GAP的结果取平均后与z相乘,得到软阈值化的结果,再与原输入进行阈值化。这段代码可以用于PyTorch中的DRSN模型实现。<span class="em">1</span><span class="em">2</span><span class="em">3</span>
#### 引用[.reference_title]
- *1* *2* [深度残差收缩网络的PyTorch代码](https://blog.csdn.net/ifalivestudy/article/details/112218925)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v92^chatsearchT3_1"}}] [.reference_item style="max-width: 50%"]
- *3* [深度残差收缩网络(DRSN)理解、PyTorch代码复现(转载)](https://blog.csdn.net/zjy1015741747/article/details/123669502)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v92^chatsearchT3_1"}}] [.reference_item style="max-width: 50%"]
[ .reference_list ]
阅读全文