Swin Transformer 的基本原理
时间: 2024-12-27 13:18:19 浏览: 10
### Swin Transformer 的基本工作原理及架构
#### 三层级的特征提取过程
Swin Transformer 是一种基于窗口机制设计的视觉变换器模型,旨在解决传统 Transformer 在处理高分辨率图像时计算量过大的问题。该模型采用分层结构来逐步降低空间分辨率并增加通道数,从而有效地捕捉不同尺度下的特征[^1]。
#### 主要组件解析
##### 1. 图像划分与嵌入 (Patch Embedding)
输入图片被均匀分割成多个不重叠的小方块(patch),每个 patch 被视为一个 token 并映射到固定长度向量表示的空间中形成序列化数据作为后续编码的基础[^3]。
```python
class PatchEmbed(nn.Module):
""" Image to Patch Embedding """
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96):
super().__init__()
self.img_size = img_size
self.patch_size = patch_size
# 定义线性投影操作用于转换维度大小
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x):
B, C, H, W = x.shape
assert H == self.img_size and W == self.img_size, \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size}*{self.img_size})."
# 将图像划分为 patches 后展平得到 tokens 序列
x = self.proj(x).flatten(2).transpose(1, 2)
return x
```
##### 2. 局部感知模块——移位窗口多头自注意力机制 (Shifted Window Multi-head Self Attention, SW-MSA)
为了减少全局范围内进行 attention 计算带来的高昂代价,Swin Transformer 提出了局部化的 window-based multi-head self-attention 方法,在此基础上进一步引入了 shift 操作使得相邻 windows 可以交互信息,增强了感受野的同时保持较低复杂度。
##### 3. 相对位置偏置 (Relative Position Bias)
考虑到绝对坐标对于旋转和平移敏感的问题,Swin Transformer 利用了相对距离定义的位置编码方式,这有助于提高模型泛化能力而不依赖于具体像素点的确切位置。
##### 4. 特征融合策略——Patch Merging
随着层数加深,网络逐渐缩小空间尺寸而扩展 channel 数目;为此特别设计了一种简单有效的下采样手段叫做 "Patch Merging", 即每隔两个元素取一次平均再经过一层全连接层实现降维重组的目的。
```python
import torch.nn.functional as F
def merge_patches(x, dim_reduction_factor=2):
N, L, D = x.size() # 输入形状为 [batch_size, num_patches, channels]
H = int(L ** .5) # 获取高度/宽度方向上原始patches的数量
x = x.view(N, H, H, D) # reshape 成二维平面形式方便后续操作
# 对每一对水平垂直相邻的四个单元格求均值后重新排列组合
merged_x = F.avg_pool2d(
input=x.permute(0, 3, 1, 2),
kernel_size=(dim_reduction_factor, dim_reduction_factor),
stride=(dim_reduction_factor, dim_reduction_factor))
new_L = (L // (dim_reduction_factor * dim_reduction_factor)) # 更新后的总patch数目
output = merged_x.flatten(start_dim=-2).permute(0, 2, 1) # 返回至初始格式 [batch_size, reduced_num_patches, increased_channels]
return output.reshape(-1, new_L, D*dim_reduction_factor**2)
```
阅读全文