Replace the positional encoding in the Swin Transformer model with a sinusoidal positional encoding, and provide the corresponding PyTorch code
Swin Transformer's absolute position embedding is a learned parameter (in addition, it uses a relative position bias inside each attention window). To replace the learned absolute embedding with a sinusoidal positional encoding, a few modifications are needed.
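For reference, the sinusoidal encoding used below follows the standard formulation from "Attention Is All You Need": for token position pos and dimension index i,

```latex
PE_{(pos,\,2i)}   = \sin\!\left(\frac{pos}{10000^{2i/d_{\mathrm{model}}}}\right), \qquad
PE_{(pos,\,2i+1)} = \cos\!\left(\frac{pos}{10000^{2i/d_{\mathrm{model}}}}\right)
```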
First, define a sinusoidal positional embedding module, for example:
```python
import math
import torch
import torch.nn as nn
class SinusoidalPositionalEmbedding(nn.Module):
    def __init__(self, d_model, max_len=512):
        super().__init__()
        self.d_model = d_model
        self.max_len = max_len
        # Precompute the sinusoidal table of shape (max_len, d_model)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)  # (1, max_len, d_model): broadcasts over the batch dimension
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: (batch, seq_len, d_model)
        x = x * math.sqrt(self.d_model)
        seq_len = x.size(1)
        # Slice the table to the current sequence length and add it (the buffer
        # follows the module's device, so no explicit .to(x.device) is needed)
        x = x + self.pe[:, :seq_len, :]
        return x
```
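A quick sanity check of the module on a dummy token sequence (the shapes below are only illustrative, chosen to match a 224x224 input with 4x4 patches and embed_dim=96):

```python
pos_embed = SinusoidalPositionalEmbedding(d_model=96, max_len=3136)
tokens = torch.randn(2, 3136, 96)  # (batch, num_patches, embed_dim); 224/4 = 56, 56*56 = 3136
out = pos_embed(tokens)
print(out.shape)  # torch.Size([2, 3136, 96])
```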
Then, in the Swin Transformer's constructor, replace the learned positional embedding with the sinusoidal one, for example:
```python
import torch
import torch.nn as nn
# SwinTransformerBlock is assumed to be available from the original Swin Transformer codebase

class SwinTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=4, in_chans=3, num_classes=1000, embed_dim=96,
                 depths=[2, 2, 18, 2], num_heads=[3, 6, 12, 24], window_size=7, mlp_ratio=4.,
                 qkv_bias=True, qk_scale=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.):
        super().__init__()
        norm_layer = nn.LayerNorm
        self.num_classes = num_classes
        self.num_features = self.embed_dim = embed_dim
        # stochastic depth decay rule
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        # patch embedding
        self.patch_embed = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm1 = norm_layer(embed_dim)
        # sinusoidal positional embedding (replaces the learned absolute position embedding)
        self.pos_embed = SinusoidalPositionalEmbedding(embed_dim, max_len=(img_size // patch_size) ** 2)
        # swin transformer blocks (simplified: a single embedding dim for all stages)
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=embed_dim, num_heads=num_heads[i], window_size=window_size,
                                 shift_size=window_size // 2 if i == 0 else 0,
                                 mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop_rate, attn_drop=attn_drop_rate,
                                 drop_path=dpr[sum(depths[:i]):sum(depths[:i + 1])])
            for i in range(len(depths))])
        # norm before classifier
        self.norm2 = norm_layer(embed_dim)
        # classification head
        self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def forward_features(self, x):
        x = self.patch_embed(x)           # (B, C, H/ps, W/ps)
        x = x.flatten(2).transpose(1, 2)  # (B, num_patches, C)
        x = self.norm1(x)                 # LayerNorm over the channel dimension
        x = self.pos_embed(x)             # add sinusoidal positional encoding
        for blk in self.blocks:
            x = blk(x)
        x = self.norm2(x)
        return x

    def forward(self, x):
        x = self.forward_features(x)
        x = x.mean(dim=1)  # global average pooling over tokens
        x = self.head(x)
        return x
```
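Assuming SwinTransformerBlock is importable from the original repository, a minimal forward-pass check could look like the sketch below (the class above is a simplified single-dim variant, not the full multi-stage Swin architecture, so this only verifies shapes):

```python
model = SwinTransformer(img_size=224, patch_size=4, embed_dim=96, num_classes=1000)
dummy = torch.randn(1, 3, 224, 224)
logits = model(dummy)
print(logits.shape)  # expected: torch.Size([1, 1000])
```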
This completes the replacement of the positional encoding in the Swin Transformer model with a sinusoidal positional encoding.