What advantages does Swin Transformer have over other Transformer architectures?
Swin Transformer is a hierarchical vision Transformer architecture. Compared with other Transformer architectures, it offers the following advantages:
1. Higher computational efficiency: instead of computing self-attention over all tokens in the image, Swin Transformer partitions the feature map into non-overlapping local windows and computes self-attention within each window; shifted windows in alternating blocks let information flow across window boundaries. This makes the attention cost grow linearly with image size rather than quadratically as in ViT's global attention. The sketch below illustrates such a windowed block (simplified from the structure of the official implementation):
```python
# Window-based (hierarchical) structure in Swin Transformer.
# Simplified sketch: WindowAttention, Mlp and DropPath are assumed to be
# available from the official Swin Transformer repo or from timm (exact
# import paths and signatures depend on the library version), and the
# attention mask required for shifted windows (SW-MSA) is omitted for brevity.
import torch
import torch.nn as nn
from timm.models.layers import DropPath, Mlp
from timm.models.swin_transformer import WindowAttention

class SwinTransformerBlock(nn.Module):
    def __init__(self, dim, num_heads, window_size, shift_size=0, mlp_ratio=4.,
                 qkv_bias=True, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.norm1 = norm_layer(dim)
        # W-MSA / SW-MSA: multi-head self-attention within local windows
        self.attn = WindowAttention(
            dim, window_size=(window_size, window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio),
                       act_layer=act_layer, drop=drop)

    def forward(self, x):
        # x: (B, H, W, C); H and W must be divisible by window_size
        B, H, W, C = x.shape
        ws = self.window_size
        shortcut = x
        x = self.norm1(x)
        # cyclic shift (SW-MSA) to create connections across window boundaries
        if self.shift_size > 0:
            x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        # partition into non-overlapping windows: (num_windows * B, ws*ws, C)
        x = x.view(B, H // ws, ws, W // ws, ws, C)
        windows = x.permute(0, 1, 3, 2, 4, 5).reshape(-1, ws * ws, C)
        # self-attention computed independently inside each window
        attn_windows = self.attn(windows)
        # merge windows back into the (B, H, W, C) feature map
        x = attn_windows.view(B, H // ws, W // ws, ws, ws, C)
        x = x.permute(0, 1, 3, 2, 4, 5).reshape(B, H, W, C)
        # reverse the cyclic shift
        if self.shift_size > 0:
            x = torch.roll(x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        # residual connection followed by the MLP (FFN) sub-block
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
```
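For reference, the block above could be exercised on a dummy feature map like this (a minimal sketch; the sizes are illustrative and H and W must stay divisible by the window size):
```python
block = SwinTransformerBlock(dim=96, num_heads=3, window_size=7)
x = torch.randn(2, 56, 56, 96)   # (B, H, W, C) feature map, 56 divisible by 7
out = block(x)
print(out.shape)                 # torch.Size([2, 56, 56, 96])
```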
2. Stronger performance: at the time of publication, Swin Transformer achieved state-of-the-art results on several vision benchmarks, including ImageNet-1K image classification, COCO object detection, and ADE20K semantic segmentation.
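As a quick way to try this, a pretrained Swin classifier can be loaded through the timm library (a sketch assuming timm is installed; `swin_tiny_patch4_window7_224` is one of the variants timm provides):
```python
import torch
import timm

# Load a pretrained Swin-Tiny classifier (224x224 input, window size 7)
model = timm.create_model('swin_tiny_patch4_window7_224', pretrained=True)
model.eval()

# Run inference on a dummy image batch
x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    logits = model(x)            # (1, 1000) ImageNet class scores
print(logits.argmax(dim=-1))
```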
3. Better scalability: by merging patches between stages, Swin Transformer builds hierarchical feature maps and can be scaled to larger image resolutions and deeper networks. Because attention is always computed inside fixed-size windows, the cost grows linearly with image size rather than quadratically, as the rough comparison below shows.
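The scaling argument can be made concrete by counting query-key pairs: global self-attention grows quadratically with the number of tokens, while window attention grows linearly (a back-of-the-envelope sketch with the window size fixed at 7, as in Swin; the numbers are illustrative only):
```python
# Compare the number of query-key pairs computed by global vs. window attention
# as the feature-map resolution grows (window size fixed at 7).
def attention_pairs(h, w, window=7):
    tokens = h * w
    global_pairs = tokens ** 2                 # every token attends to every token
    windowed_pairs = tokens * window * window  # each token attends within its window
    return global_pairs, windowed_pairs

for size in (56, 112, 224):
    g, s = attention_pairs(size, size)
    print(f"{size}x{size}: global={g:.2e}, windowed={s:.2e}, ratio={g / s:.0f}x")
```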