swin transformer的block
时间: 2024-12-27 20:28:33 浏览: 5
### Swin Transformer Block 架构详解
Swin Transformer 的核心在于其独特的窗口划分机制以及局部和全局建模能力。具体来说,Swin Transformer 将输入特征图划分为不重叠的局部窗口,在这些窗口内执行自注意力计算[^1]。
#### 局部窗口内的多头自注意力(W-MSA)
在每个窗口内部,采用标准的多头自注意力机制来捕捉局部区域的信息交互。这种设计使得模型能够在保持较低计算复杂度的同时有效地提取局部特征。
```python
class WindowAttention(nn.Module):
"""Window based multi-head self attention (W-MSA) module with relative position bias.
"""
def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.):
super().__init__()
# 定义参数...
def forward(self, x):
B_, N, C = x.shape
# 计算QKV矩阵并进行缩放点积操作...
return x
```
#### 移动窗口策略(Shifted Windows)
为了增强跨窗口信息交流的能力,引入了移位窗口技术。通过交替使用常规窗口和平移后的窗口布局,实现了相邻窗口之间的信息传递,从而增强了整体感受野范围。
```python
def cyclic_shift(x, shift_size):
H, W = x.size()[-2:]
shifted_x = torch.roll(x, shifts=(-shift_size, -shift_size), dims=(1, 2))
img_mask = torch.zeros((1, H, W, 1)).cuda()
h_slices = slice(0, -window_size),
slice(-window_size, -shift_size),
slice(-shift_size, None)
w_slices = slice(0, -window_size),
slice(-window_size, -shift_size),
slice(-shift_size, None)
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
mask_windows = window_partition(img_mask, window_size).view(-1, window_size * window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
return shifted_x, attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
```
#### 多尺度融合模块
除了上述主要组件外,还包含了线性投影层、LayerNorm 归一化处理单元等辅助部分,共同构成了完整的 Swin Transformer 块架构。
阅读全文