bottleneck结构图
时间: 2025-01-02 13:28:04 浏览: 17
### Bottleneck 结构概述
Bottleneck结构广泛应用于现代深度学习网络中,特别是像ResNet这样的深层网络。这一结构旨在减少计算量的同时保持甚至提升性能。在典型的Bottleneck单元内,输入先经过1×1卷积核压缩通道数,随后通过3×3卷积核处理空间信息,最后再利用另一个1×1卷积核恢复原始维度[^2]。
对于ResNet中的Bottleneck模块而言,通常采用的是(1 × 1, 64)-(3 × 3, 64)-(1 × 1, 256)的形式,在这里第一个1×1卷积用于降维,中间的3×3负责捕捉局部特征,而最后一个1×1则用来升维以便后续操作。
```python
import torch.nn as nn
class Bottleneck(nn.Module):
expansion = 4
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(Bottleneck, self).__init__()
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes * self.expansion)
self.relu = nn.ReLU(inplace=True)
self.downsample = downsample
self.stride = stride
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
out = self.relu(out)
out = self.conv3(out)
out = self.bn3(out)
if self.downsample is not None:
residual = self.downsample(x)
out += residual
out = self.relu(out)
return out
```
上述代码展示了如何实现一个标准的Bottleneck层,其中包含了三个连续的不同尺寸卷积操作以及批标准化和激活函数的应用。
阅读全文