残差连接卷积神经网络
时间: 2025-01-06 12:34:31 浏览: 16
### 带有残差连接的卷积神经网络
#### 定义与背景
卷积神经网络(ConvNets 或 CNNs)在图像识别和分类等领域表现出色[^1]。传统上,这些模型通过堆叠多层来提取特征并进行预测。然而,在深层网络中,随着层数增加,梯度消失问题变得严重,导致训练困难。
#### 残差连接的作用
为了克服这一挑战,引入了带有跳跃连接的设计——即所谓的 **残差网络** (ResNet)。这种架构允许信息绕过某些层次结构直接传递给后续层。具体来说,残差块由两个主要部分组成:
- 主路径:执行标准卷积操作;
- 跳跃/捷径连接:将输入不经任何变换地加到输出上去;
这样做的好处是可以缓解退化现象,并使得更深的网络更容易优化。
```python
import torch.nn as nn
class ResidualBlock(nn.Module):
def __init__(self, in_channels, out_channels, stride=1):
super(ResidualBlock, self).__init__()
# Main path with two convolutions
self.main_path = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(),
nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels)
)
# Shortcut connection may involve downsampling if necessary
self.shortcut = nn.Sequential()
if stride != 1 or in_channels != out_channels:
self.shortcut = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride),
nn.BatchNorm2d(out_channels)
)
def forward(self, x):
identity = self.shortcut(x)
out = self.main_path(x)
out += identity
return F.relu(out)
# Example usage within a larger network definition
def make_layer(block, planes, blocks, stride=1):
layers = []
layers.append(block(in_planes, planes, stride))
in_planes = planes * block.expansion
for i in range(1, blocks):
layers.append(block(in_planes, planes))
return nn.Sequential(*layers)
```
上述代码展示了如何实现一个简单的残差模块 `ResidualBlock` 及其应用方式。注意这里包含了两种情况下的处理逻辑:当维度匹配时可以直接相加;如果不匹配,则需先调整尺寸再叠加。
阅读全文