swin transformer中的掩码机制
时间: 2023-11-10 16:00:45 浏览: 236
Swin Transformer中的掩码机制是一种用于处理图像序列的注意力机制。它通过将图像分成若干个滑动窗口,并在每个窗口中计算注意力权重,从而实现对图像的局部化处理。具体而言,掩码机制将每个滑动窗口与序列中的其他窗口进行注意力计算,以便突出窗口内的重要特征信息,并通过对窗口内特征进行自适应加权求和,来生成窗口级别的表示。这种窗口级别的表示可以更好地捕捉到图像中不同区域的特征,并提高模型在图像分类、目标检测等任务上的性能。
相关问题
Swin Transformer中掩码的作用
Swin Transformer中的掩码(mask)是用于指示哪些位置需要被注意力机制忽略的一种机制。在自注意力机制中,每个位置都会关注所有其他位置的信息。然而,在某些情况下,我们希望模型只关注特定位置或特定类型的信息,而忽略其他位置或类型的信息。这时候,我们可以使用掩码来限制模型的注意力范围。
在Swin Transformer中,有两种常见的掩码方式:padding mask(填充掩码)和look-ahead mask(前瞻掩码)。填充掩码用于处理变长序列,通过将序列中的填充部分标记为0,使模型不会关注到这些填充部分。前瞻掩码则用于处理序列生成任务,通过将序列中未生成的部分标记为0,以避免模型在生成某个位置的时候依赖于其后面未生成的位置。
通过使用掩码,Swin Transformer能够更灵活地处理不同类型的序列数据,提高模型性能和效率。
swin transformer掩码
### Swin Transformer 中的掩码机制解析
在 Swin Transformer 的架构设计中,为了有效处理局部特征并减少计算复杂度,引入了基于移位窗口的操作。每个窗口内的自注意力操作仅限于该窗口内部的 token 之间,这使得不同窗口之间的交互被阻断。
当执行移位窗口划分时,原始图像会被分割成不重叠的窗口。对于标准窗口划分,在奇数层中保持不变;而在偶数层,则通过移动窗口位置来创建新的窗口组合方式[^1]。这种交替模式确保了相邻窗口间的信息流动。
#### 掩码生成过程
由于窗口间的 token 不应相互影响,因此需要构建相应的 attention mask 来屏蔽掉不属于当前窗口范围内的元素:
1. **初始化 Mask**: 对于每一个窗口大小 \(7 \times 7\) 或其他指定尺寸,会预先定义好一个全零矩阵作为基础模板。
2. **填充 Mask 值**: 当应用移位窗口策略后,某些原本属于同一窗口但在新布局下跨越边界的区域应当被标记出来。这些超出边界的部分将在后续计算过程中被忽略,即设置对应的 mask 项为负无穷大(-inf),从而让 softmax 函数将其权重压缩至接近0。
3. **应用于 Attention 计算**: 在多头自注意模块 Multi-head Self-Attention(MSA) 内部,输入序列经过线性变换投影得到 QKV 后,attention score 将乘上之前准备好的 mask 矩阵。这样可以保证只有来自相同窗口内部节点才能互相作用,而跨窗连接则完全失效。
```python
import torch
from einops import rearrange, repeat
def create_attn_mask(window_size, shift_size):
H, W = window_size
img_mask = torch.zeros((1, H, W, 1))
h_slices = (slice(0, -shift_size),
slice(-shift_size, None))
w_slices = (slice(0, -shift_size),
slice(-shift_size, None))
cnt = 0
for h in h_slices:
for w in w_slices:
img_mask[:, h, w, :] = cnt
cnt += 1
# Convert to mask matrix used during MHA computation.
mask_windows = rearrange(img_mask.view(1, H//window_size[0], window_size[0], W//window_size[1], window_size[1]), 'n wh ws1 ww ws2 -> n (wh ww) (ws1 ws2)')
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float("-inf")).masked_fill(attn_mask == 0, float(0.0))
return attn_mask
```
此代码片段展示了如何根据给定的 `window_size` 和 `shift_size` 创建用于遮蔽非本地化token间关联性的 attention mask 。具体来说,先构造了一个二维平面内各像素所属分区编号的地图 (`img_mask`) ,再依据其转换成为适用于批量运算的形式(`mask_windows`) 并最终形成完整的 attention mask 表达形式。
阅读全文
相关推荐
















