windows-Attention代码讲解
我想您可能是想问关于 Self-Attention 的代码讲解,因为 Self-Attention 是 Transformer 中的重要组成部分,Windows-Attention 不是一个通用的术语。以下是一个简单的 Self-Attention 代码讲解:
- 定义 Self-Attention 层
在代码中,定义了一个 Self-Attention 层。Self-Attention 层包含一个全连接层,用于将输入的特征向量映射到一个低维空间,然后对映射后的特征向量进行注意力计算。在这个例子中,使用了一个简单的注意力计算方法,即将映射后的特征向量相乘,然后进行归一化,得到注意力权重。最后将注意力权重与映射后的特征向量相乘,得到输出特征向量。
- 输入特征向量
在代码中,定义了一个输入特征向量 x,它是一个二维张量,包含了 batch_size 个样本和每个样本的特征向量。在这个例子中,特征向量的维度为 d_model。
- 调用 Self-Attention 层
在代码中,调用了定义的 Self-Attention 层,并将输入特征向量 x 作为参数传递给 Self-Attention 层。Self-Attention 层计算出注意力权重和输出特征向量,并将输出特征向量返回。
- 输出特征向量
在代码中,输出了 Self-Attention 层计算得到的输出特征向量。
以上是一个简单的 Self-Attention 代码讲解,如果您有任何问题或需要更详细的解释,请随时询问。
swin transformer源代码讲解
Swin Transformer 源码解析
一、整体框架概述
Swin Transformer 是一种基于窗口的分层视觉Transformer,其设计旨在提高计算效率并增强局部特征表达能力。通过引入移位窗口机制和层次化结构,使得模型能够在不同尺度上捕捉空间信息[^1]。
二、核心组件剖析
(一)Patch Partitioning (打补丁划分)
为了适应图像数据的特点,在输入阶段会将原始图片分割成多个不重叠的小方块(patch),这些patch会被展平作为后续处理的基本单元。具体来说,给定大小为 (H \times W) 的RGB 图像,假设 patch size 设置为(P\times P), 那么最终得到的 patches 数目就是 ((H/P)\times(W/P))[^2]。
import torch.nn as nn
class PatchEmbed(nn.Module):
""" Image to Patch Embedding """
def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96):
super().__init__()
self.patch_embed = nn.Conv2d(in_channels=in_chans,
out_channels=embed_dim,
kernel_size=(patch_size, patch_size),
stride=(patch_size, patch_size))
def forward(self, x):
B, C, H, W = x.shape
assert H % self.patch_size == 0 and W % self.patch_size == 0,\
f"Input image size ({H}*{W}) should be divisible by patch size {self.patch_size}"
x = self.patch_embed(x).flatten(2).transpose(1, 2) # BCHW -> BNC
return x
(二)Window-based Multi-head Self Attention (W-MSA)
不同于传统全局自注意力机制,这里采用的是限定范围内的多头自我注意操作——即只考虑当前窗口内元素之间的关系。这样做不仅减少了内存消耗还加快了训练速度。当涉及到跨窗交互时,则借助于shifted window strategy来实现[^3]。
from timm.models.layers import DropPath, trunc_normal_
def window_partition(x, window_size):
""" 将feature map按照window_size切分成若干个小窗口"""
...
def window_reverse(windows, window_size, H, W):
""" 把之前按window partition后的tensor重新拼接回原来的尺寸"""
...
class WindowAttention(nn.Module):
r""" Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both of shifted and non-shifted window.
Args:
dim (int): Number of input channels.
window_size (tuple[int]): The height and width of the window.
num_heads (int): Number of attention heads.
qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""
def __init__(self, dim, window_size, num_heads, qkv_bias=True, attn_drop=0., proj_drop=0.):
super().__init__()
...
(三)Shifted Window Mechanism
此部分主要解决相邻窗口间的信息交流问题。简单来讲就是在奇数层执行一次水平/垂直方向上的偏移后再做w-msa;偶数层则保持不动继续常规流程。这样既保留了一定程度的空间连续性又不会造成过多额外开销[^4]。
class SwinTransformerBlock(nn.Module):
def __init__(...):
...
def forward(...):
shortcut = x
if shift_size > 0:# 判断是否需要进行shift operation
shifted_x = roll(shifted_x,-shift_size,dims=(-2))# 对应论文里的roll function
x_windows = window_partition(shifted_x,...)# 进行window partition
attn_windows = self.attn(x_windows,...)# 执行attention mechanism
shifted_x = window_reverse(attn_windows,...)# 反向恢复到原图结构
if shift_size > 0:
x = roll(shifted_x,+shift_size,dims=(-2))# 如果进行了shift就要记得还原回去哦~
else :
x = shifted_x
x = shortcut + self.drop_path(self.norm1(x))
x = x + self.drop_path(self.mlp(self.norm2(x)))
return x
三、总结
通过对以上几个重要组成部分的学习可以看出,Swin Transformer 在继承了经典transformer优点的基础上做了很多针对性优化,特别是在提升计算性能方面表现尤为突出。对于想要深入了解该网络内部工作原理的人来说,官方给出的GitHub仓库提供了详尽完整的源代码资源可供参考学习[^5]。
相关推荐






