写一个 PyTorch 框架下、输入为 (64, 3, 128, 128) 的 ViT 五分类网络,要求带有 CBAM、残差结构、BN 层和 Dropout 层,运行不报错,且每一层的维度正确。
时间: 2024-05-03 16:22:17 浏览: 207
基于pytorch的分类网络
5星 · 资源好评率100%
以下是一个符合上述要求的 ViT 五分类网络的 PyTorch 实现:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
class MLP(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> dropout -> Linear.

    Args:
        in_features: Size of the last dimension of the input.
        hidden_features: Width of the hidden layer.
        out_features: Size of the last dimension of the output.
        dropout: Probability of zeroing a hidden activation while the
            module is in training mode. Defaults to 0.1, the value that
            used to be hard-coded in ``forward``.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: int,
        out_features: int,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.dropout_p = dropout

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``(..., in_features)`` to ``(..., out_features)``."""
        x = F.relu(self.fc1(x))
        # Functional dropout must be told the mode explicitly; tying it to
        # ``self.training`` makes ``model.eval()`` disable it.
        x = F.dropout(x, p=self.dropout_p, training=self.training)
        return self.fc2(x)
class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        dim: Width of each input/output token. Must be divisible by
            ``heads`` because the projected q/k/v are split evenly
            across heads.
        heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights
            and to the output projection.

    Raises:
        ValueError: If ``dim`` is not divisible by ``heads``.
    """

    def __init__(self, dim, heads=8, dropout=0.1):
        super().__init__()
        if dim % heads != 0:
            raise ValueError(
                f'dim ({dim}) must be divisible by heads ({heads})'
            )
        self.heads = heads
        # Scale by the per-head width: each dot product sums over
        # dim // heads terms, not over the full model width.
        self.scale = (dim // heads) ** -0.5
        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
        self.to_out = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Attend over a ``(batch, tokens, dim)`` tensor; shape is preserved."""
        b, n, _ = x.shape
        h = self.heads
        # (b, n, dim) -> three tensors of shape (b, h, n, dim // h)
        q, k, v = (
            t.reshape(b, n, h, -1).transpose(1, 2)
            for t in self.to_qkv(x).chunk(3, dim=-1)
        )
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.dropout(dots.softmax(dim=-1))
        out = torch.matmul(attn, v)
        # Merge the heads back into the token width.
        out = out.transpose(1, 2).reshape(b, n, -1)
        return self.dropout(self.to_out(out))
class Residual(nn.Module):
    """Skip-connection wrapper: returns ``fn(x) + x``.

    Args:
        fn: Callable (typically an ``nn.Module``) whose output can be
            added to its input.
    """

    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        """Run the wrapped callable and add the untouched input back on."""
        branch = self.fn(x)
        return branch + x
class PreNorm(nn.Module):
    """Apply ``LayerNorm`` over the last dimension, then call ``fn``.

    Args:
        dim: Size of the last dimension that gets normalised.
        fn: Callable invoked on the normalised tensor.
    """

    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x):
        """Normalise ``x`` and hand the result to the wrapped callable."""
        normed = self.norm(x)
        return self.fn(normed)
class FeedForward(nn.Module):
    """Position-wise feed-forward block: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: Width of the input and output tokens.
        hidden_dim: Width of the intermediate layer.
        dropout: Probability used by both dropout layers.
    """

    def __init__(self, dim, hidden_dim, dropout=0.1):
        super().__init__()
        # The order below fixes the indices inside ``self.net`` (and thus
        # the state-dict keys): 0 = expand, 3 = project back.
        stages = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        """Transform ``(..., dim)`` to ``(..., dim)`` through the stack."""
        out = self.net(x)
        return out
class CBAM(nn.Module):
    """Convolutional Block Attention Module: channel gate, then spatial gate.

    The channel gate runs a shared bottleneck MLP over the globally
    average-pooled and max-pooled descriptors, sums the two results and
    squashes them with a sigmoid. The spatial gate concatenates the
    per-pixel mean and max over channels of the channel-refined map,
    convolves them to a single channel and applies a sigmoid.

    Args:
        in_features: Number of channels of the input feature map.
        reduction_ratio: Bottleneck reduction factor of the shared MLP.
        spatial_kernel_size: Odd kernel size of the spatial-gate
            convolution.

    Raises:
        ValueError: If ``spatial_kernel_size`` is even, since "same"
            padding would then not preserve the spatial size.

    Note:
        ``spatial_conv`` adds a weight that checkpoints saved from a
        channel-only variant do not contain; load those with
        ``strict=False``.
    """

    def __init__(self, in_features, reduction_ratio=16, spatial_kernel_size=7):
        super().__init__()
        if spatial_kernel_size % 2 == 0:
            raise ValueError(
                f'spatial_kernel_size must be odd, got {spatial_kernel_size}'
            )
        self.in_features = in_features
        self.reduction_ratio = reduction_ratio
        # Keep at least one hidden unit; plain floor division yields a
        # zero-width bottleneck when in_features < reduction_ratio.
        hidden_features = max(1, in_features // reduction_ratio)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.max_pool = nn.AdaptiveMaxPool2d((1, 1))
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_features, in_features)
        self.sigmoid = nn.Sigmoid()
        self.spatial_conv = nn.Conv2d(
            2,
            1,
            kernel_size=spatial_kernel_size,
            padding=spatial_kernel_size // 2,
            bias=False,
        )

    def forward(self, x):
        """Re-weight a ``(batch, channels, height, width)`` map; shape is preserved."""
        b, c, _, _ = x.size()

        # Channel attention: the same MLP scores both pooled descriptors.
        avg_out = self.fc2(self.relu(self.fc1(self.avg_pool(x).view(b, c))))
        max_out = self.fc2(self.relu(self.fc1(self.max_pool(x).view(b, c))))
        x = x * self.sigmoid(avg_out + max_out).view(b, c, 1, 1)

        # Spatial attention on the channel-refined map.
        descriptors = torch.cat(
            (
                x.mean(dim=1, keepdim=True),
                x.max(dim=1, keepdim=True).values,
            ),
            dim=1,
        )
        return x * self.sigmoid(self.spatial_conv(descriptors))
class VIT(nn.Module):
    """Vision Transformer classifier with a CBAM + BatchNorm refinement head.

    Pipeline: split the RGB image into non-overlapping patches, embed
    them linearly, prepend a learnable class token, add a learnable
    positional embedding, run ``depth`` pre-norm residual
    attention/feed-forward pairs, then classify from the class token
    plus the spatially averaged, CBAM-refined and batch-normalised patch
    features.

    Args:
        image_size: Side length of the (square) input images.
        patch_size: Side length of each square patch; must divide
            ``image_size``.
        num_classes: Number of output logits.
        dim: Token embedding width.
        depth: Number of attention + feed-forward pairs.
        heads: Number of attention heads per attention layer.
        mlp_dim: Hidden width of each feed-forward layer.
        dropout: Dropout probability for the transformer layers and the
            features fed to the classification head.

    Raises:
        AssertionError: If ``image_size`` is not divisible by
            ``patch_size``.
    """

    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, dropout):
        super().__init__()
        assert image_size % patch_size == 0, 'image size must be divisible by patch size'
        num_patches = (image_size // patch_size) ** 2
        patch_dim = 3 * patch_size ** 2
        self.patch_size = patch_size
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.patch_to_embedding = nn.Linear(patch_dim, dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(dropout)

        # One attention block followed by one feed-forward block per unit
        # of depth, each wrapped in pre-norm and a residual connection.
        blocks = []
        for _ in range(depth):
            blocks.append(Residual(PreNorm(dim, Attention(dim, heads=heads, dropout=dropout))))
            blocks.append(Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout=dropout))))
        self.transformer = nn.ModuleList(blocks)

        self.to_cls_token = nn.Identity()
        self.mlp_head = MLP(dim, hidden_features=512, out_features=num_classes)
        self.cbam = CBAM(in_features=dim)
        self.bn = nn.BatchNorm2d(dim)

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: Tensor of shape ``(batch, 3, H, W)`` where ``H`` and ``W``
                are divisible by the patch size and yield the number of
                patches the positional embedding was built for.

        Returns:
            Logits of shape ``(batch, num_classes)``.

        Raises:
            AssertionError: If the channel count is not 3, the spatial
                size is not divisible by the patch size, or the patch
                count does not match the positional embedding.
        """
        b, c, h, w = x.shape
        p = self.patch_size
        assert c == 3, f'expected 3 input channels, got {c}'
        # assert input size is divisible by patch size
        assert h % p == 0 and w % p == 0, f'image size {h}x{w} not divisible by patch size {p}'
        grid_h, grid_w = h // p, w // p
        assert grid_h * grid_w + 1 == self.pos_embedding.shape[1], (
            f'image size {h}x{w} yields {grid_h * grid_w} patches, but the positional '
            f'embedding was built for {self.pos_embedding.shape[1] - 1}'
        )

        # Convert the image to flattened patches: (b, grid_h * grid_w, 3 * p * p),
        # ordered row-major over the patch grid.
        x = x.reshape(b, 3, grid_h, p, grid_w, p)
        x = x.permute(0, 2, 4, 1, 3, 5)
        x = x.reshape(b, grid_h * grid_w, 3 * p ** 2)

        # Embed the patches first; only then do they share the class token's
        # width and can be concatenated with it.
        x = self.patch_to_embedding(x)
        cls_token = self.cls_token.expand(b, -1, -1)
        x = torch.cat((cls_token, x), dim=1)

        # add positional embedding
        x = x + self.pos_embedding

        # apply transformer
        for transformer_block in self.transformer:
            x = transformer_block(x)

        # Class-token summary: (b, dim).
        cls_out = self.to_cls_token(x[:, 0])

        # Fold the patch tokens back into a (b, dim, grid_h, grid_w) map so the
        # 2-D CBAM and BatchNorm layers can operate on it, then average over space.
        feature_map = x[:, 1:].transpose(1, 2).reshape(b, -1, grid_h, grid_w)
        feature_map = self.bn(self.cbam(feature_map))
        pooled = feature_map.mean(dim=(2, 3))

        x = self.dropout(cls_out + pooled)
        x = self.mlp_head(x)
        return x
```
使用方式:
```python
import torch
# Smoke test: build the five-class model and push one random batch through it.
model = VIT(
    image_size=128,
    patch_size=16,
    num_classes=5,
    dim=256,
    depth=6,
    heads=8,
    mlp_dim=512,
    dropout=0.1,
)
input_tensor = torch.randn(64, 3, 128, 128)
output_tensor = model(input_tensor)
print(output_tensor.shape)  # expected output shape: (64, 5)
```
阅读全文