swin transformer块
时间: 2025-01-03 08:28:57 浏览: 12
### Swin Transformer 架构详解
#### 层次化视觉变换器结构
Swin Transformer采用层次化的架构来处理输入图像,这种设计使得模型能够捕捉不同尺度的信息。具体来说,该网络由多个阶段组成,每个阶段负责提取特定分辨率下的特征图[^1]。
#### 移动窗口机制
为了提高局部建模能力并减少全局依赖带来的计算复杂度,Swin Transformer提出了移动窗口自注意机制(Shifted Window-based Self-Attention)。此方法允许在同一层内相邻区域之间建立联系,而不需要增加过多参数量或降低速度表现[^2]。
#### 补丁划分与线性嵌入
在进入第一个Transformer block之前,原始图片会被分割成不重叠的小方块——即“patches”,这些patch随后被映射到固定维度向量空间作为token表示形式用于后续处理过程中的交互操作。
#### 逐级下采样策略
除了基本构建单元外,还存在一种特殊的跨阶段连接方式叫做Patch Merging Layer,它可以在不影响其他部分正常工作的前提下有效地缩小感受野大小从而达到降维目的;与此同时也能增强多尺度表征学习的效果。
```python
import torch.nn as nn
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
class PatchEmbed(nn.Module):
""" Image to Patch Embedding """
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
self.img_size = img_size
self.patch_size = patch_size
self.patches_resolution = patches_resolution
self.num_patches = patches_resolution[0] * patches_resolution[1]
self.in_chans = in_chans
self.embed_dim = embed_dim
# 定义卷积层完成从像素到补丁的转换
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
def forward(self, x):
B, C, H, W = x.shape
assert H == self.img_size[0] and W == self.img_size[1], \
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
# 执行投影得到补丁序列
x = self.proj(x).flatten(2).transpose(1, 2)
return x
```
上述代码展示了如何定义一个简单的`PatchEmbed`类来进行图像至补丁的转化工作,并将其转化为适合送入Transformer编码器的形式。
阅读全文