Swin Transformer 模型代码
时间: 2023-10-25 13:19:46 浏览: 120
以下是一个简化版 Swin Transformer 的 PyTorch 代码实现,包括模型定义和训练过程。注意:该实现在展平后的一维 token 序列上划分窗口,各个 stage 使用相同的 embed_dim(没有 patch merging),也没有相对位置偏置,因此与论文中的完整结构并不完全一致:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
class SwinBlock(nn.Module):
    """Simplified Swin-style transformer block for a 1-D token sequence.

    The block takes and returns a tensor of shape ``(batch, tokens, dim)``.
    It applies, in order:

    1. LayerNorm, then (optionally) attention restricted to windows of
       ``window_size`` consecutive tokens, with a residual inside the window
       branch.
    2. Attention over the full token sequence, with residual connections.
    3. LayerNorm followed by a two-layer GELU MLP, with residual connections.

    Args:
        dim: Channel dimension of every token.
        num_heads: Number of attention heads; ``dim`` must be divisible by it.
        window_size: Number of consecutive tokens per attention window.
        shift_size: Number of positions the sequence is cyclically shifted
            before it is split into windows. The shift is undone afterwards,
            so output tokens stay aligned with input tokens.

    The windowed branch is skipped entirely when ``window_size == 1`` and
    ``shift_size == 0``.
    """

    def __init__(self, dim, num_heads, window_size=7, shift_size=0):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        # batch_first=True because every tensor in this block is laid out as
        # (batch, tokens, dim); the default layout is (tokens, batch, dim),
        # which would make attention mix samples of the batch.
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim),
        )
        self.window_size = window_size
        self.shift_size = shift_size
        if window_size == 1 and shift_size == 0:
            self.window_attn = None
        else:
            self.window_attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def _window_attention(self, x):
        """Apply attention inside (optionally shifted) windows of tokens.

        Args:
            x: Tensor of shape ``(batch, tokens, dim)``.

        Returns:
            Tensor of the same shape: ``x`` plus the per-window attention
            output.

        Raises:
            AssertionError: If the token count is not a multiple of
                ``window_size``.
        """
        b, n, d = x.shape
        assert n % self.window_size == 0, "sequence length must be divisible by window size"
        shifted = self.shift_size > 0
        if shifted:
            # Cyclic shift so that window boundaries fall at different
            # positions than in an unshifted block. No attention mask is
            # used, so the window that wraps around mixes tokens from both
            # ends of the sequence.
            x = torch.roll(x, shifts=-self.shift_size, dims=1)
        # Each window becomes its own batch entry: (b * num_windows, ws, d).
        windows = x.reshape(b * (n // self.window_size), self.window_size, d)
        attended = self.window_attn(windows, windows, windows, need_weights=False)[0]
        # The residual is added while both tensors are still in window layout,
        # so the shapes always match.
        x = (windows + attended).reshape(b, n, d)
        if shifted:
            # Undo the shift so the token order matches the input again.
            x = torch.roll(x, shifts=self.shift_size, dims=1)
        return x

    def forward(self, x):
        """Run the block on ``x`` of shape ``(batch, tokens, dim)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        res = x
        x = self.norm1(x)
        if self.window_attn is not None:
            x = self._window_attention(x)
        # Attention over the whole sequence, added to the branch and then to
        # the block input.
        x = x + self.attn(x, x, x, need_weights=False)[0]
        x = res + x
        # MLP sub-layer, with the same branch-plus-input residual structure.
        res = x
        x = self.norm2(x)
        x = x + self.mlp(x)
        x = res + x
        return x
class SwinTransformer(nn.Module):
    """Image classifier built from a patch embedding and stacked SwinBlocks.

    The image is split into non-overlapping ``patch_size`` x ``patch_size``
    patches by a strided convolution. A learned position embedding is added,
    the tokens are layer-normalised and passed through every stage, then
    averaged over all patches and projected to ``num_classes`` logits.

    Args:
        img_size: Side length of the (square) input image in pixels; must be
            divisible by ``patch_size``.
        patch_size: Side length of one patch in pixels.
        in_chans: Number of input image channels.
        num_classes: Number of output logits.
        embed_dim: Token dimension, shared by every stage.
        depths: Iterable giving the number of SwinBlocks in each stage.
        num_heads: Attention heads per block.
        window_size: Window size forwarded to every SwinBlock.
        shift_size: Shift size forwarded to every SwinBlock.

    Raises:
        AssertionError: If ``img_size`` is not divisible by ``patch_size``.
    """

    def __init__(self, img_size, patch_size, in_chans, num_classes, embed_dim, depths, num_heads, window_size=7, shift_size=0):
        super().__init__()
        assert img_size % patch_size == 0, "image size must be divisible by patch size"
        num_patches = (img_size // patch_size) ** 2
        self.patch_size = patch_size
        # Patch Embeddings
        self.patch_embed = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
        # Stages
        self.stages = nn.ModuleList([
            nn.Sequential(*[
                SwinBlock(embed_dim, num_heads, window_size, shift_size)
                for _ in range(depth)
            ])
            for depth in depths
        ])
        # Classifier Head
        self.norm = nn.LayerNorm(embed_dim)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
        self.fc = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: Tensor of shape ``(batch, in_chans, img_size, img_size)``.

        Returns:
            Logits of shape ``(batch, num_classes)``.
        """
        x = self.patch_embed(x)
        # h and w are already the patch-grid size (the convolution stride did
        # the division by patch_size), so they must not be divided again.
        b, c, h, w = x.shape
        x = x.flatten(2).transpose(1, 2)  # (b, h * w, c)
        x = x + self.pos_embed
        x = self.norm(x)
        for stage in self.stages:
            x = stage(x)
        # Restore the (b, c, h, w) grid so the 2-D pooling layer can average
        # over all patches.
        x = x.transpose(1, 2).reshape(b, c, h, w)
        x = self.avgpool(x).flatten(1)
        return self.fc(x)
```
训练过程(运行前需要自行定义训练轮数 `num_epochs` 和数据加载器 `trainloader`):
```python
import torch.optim as optim
# Define Model
model = SwinTransformer(
    img_size=224,
    patch_size=4,
    in_chans=3,
    num_classes=10,
    embed_dim=96,
    depths=[2, 2, 6, 2],
    num_heads=6,
    window_size=7,
    shift_size=2,
)
# Use the GPU when one is available and fall back to the CPU otherwise, so the
# script does not crash on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
# Define Loss and Optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)
# Train Model
# NOTE: `num_epochs` (int) and `trainloader` (an iterable of
# (inputs, labels) batches that supports len()) are not defined in this
# snippet and must be provided before it runs.
model.train()
for epoch in range(num_epochs):
    running_loss = 0.0
    for inputs, labels in trainloader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # Guard against an empty loader so the average never divides by zero.
    num_batches = max(len(trainloader), 1)
    print(f"Epoch {epoch+1}: Training Loss = {running_loss/num_batches}")
阅读全文