Write a five-class classification network in PyTorch that combines a Vision Transformer (ViT) with CBAM, for an input of shape (64, 3, 128, 128).
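Below is one possible implementation sketch. It first defines the convolution and attention building blocks (ConvBlock, SEBlock, CBAMBlock), then a ViT-style patch embedding, multi-head self-attention, and Transformer blocks, and finally a VisionTransformer with a 5-way classification head. A short smoke test at the end pushes a (64, 3, 128, 128) batch through the model, and a follow-up sketch after the test shows one way to place the CBAM block in front of the patch embedding.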
import torch
import torch.nn as nn
import torch.nn.functional as F
class ConvBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride):
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding=kernel_size // 2)
        self.bn = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x
class SEBlock(nn.Module):
    # Squeeze-and-Excitation: global average pooling plus a bottleneck MLP produces
    # per-channel weights in (0, 1) that rescale the input feature map.
    def __init__(self, channels, reduction=16):
        super(SEBlock, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        y = self.avg_pool(x)
        y = self.fc1(y)
        y = self.relu(y)
        y = self.fc2(y)
        y = self.sigmoid(y)
        return x * y
class CBAMBlock(nn.Module):
    # Channel attention (SE plus an avg+max pooled gate) followed by spatial attention
    # (7x7 conv over channel-wise mean/max maps), then a 3x3 conv refinement.
    def __init__(self, channels, reduction=16):
        super(CBAMBlock, self).__init__()
        self.se = SEBlock(channels, reduction)
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1)
        self.bn = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)
        # channel attention: bottleneck MLP over average- and max-pooled descriptors
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc = nn.Sequential(
            nn.Conv2d(channels, channels // 2, kernel_size=1, stride=1, padding=0),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels // 2, channels, kernel_size=1, stride=1, padding=0),
            nn.Sigmoid()
        )
        # spatial attention: 7x7 conv over the concatenated channel-wise mean and max maps
        self.spatial = nn.Sequential(
            nn.Conv2d(2, 1, kernel_size=7, stride=1, padding=3),
            nn.Sigmoid()
        )

    def forward(self, x):
        y = self.se(x)
        # channel attention weights
        w = self.avg_pool(y) + self.max_pool(y)
        w = self.fc(w)
        y = x * w
        # spatial attention weights
        s = torch.cat([y.mean(dim=1, keepdim=True), y.max(dim=1, keepdim=True)[0]], dim=1)
        y = y * self.spatial(s)
        y = self.conv(y)
        y = self.bn(y)
        y = self.relu(y)
        return y
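The block is shape-preserving, so it can sit in front of any stage that consumes a (B, C, H, W) feature map. A quick sanity check (the channel count and spatial size here are just illustrative choices):

cbam = CBAMBlock(channels=64, reduction=16)
feat = torch.randn(2, 64, 32, 32)
print(cbam(feat).shape)  # torch.Size([2, 64, 32, 32]) -- same shape as the input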
class MLP(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None):
        super(MLP, self).__init__()
        # fall back to the input width when the hidden/output widths are not given
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x
class PatchEmbedding(nn.Module):
    def __init__(self, img_size=128, patch_size=16, in_chans=3, embed_dim=768):
        super(PatchEmbedding, self).__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.n_patches = (img_size // patch_size) ** 2
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)       # (B, E, H/patch, W/patch)
        x = x.flatten(2)       # (B, E, n_patches)
        x = x.transpose(1, 2)  # (B, n_patches, E)
        return x
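With the defaults used here (img_size=128, patch_size=16), the conv projection yields an 8x8 grid, i.e. 64 patch tokens of dimension 768:

patch_embed = PatchEmbedding(img_size=128, patch_size=16, in_chans=3, embed_dim=768)
tokens = patch_embed(torch.randn(2, 3, 128, 128))
print(tokens.shape)  # torch.Size([2, 64, 768]) -- (128 // 16) ** 2 = 64 patches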
class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=True, attn_drop=0., proj_drop=0.):
        super(Attention, self).__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, n, C = x.shape
        qkv = self.qkv(x).reshape(B, n, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        # make torchscript happy (cannot use tensor as tuple)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, n, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0., drop_path=0.):
        super(Block, self).__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(
            dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        # nn.Dropout is used here as a simple stand-in for stochastic depth (DropPath)
        self.drop_path = nn.Dropout(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = nn.LayerNorm(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_hidden_dim),
            nn.GELU(),
            nn.Dropout(drop),
            nn.Linear(mlp_hidden_dim, dim),
            nn.Dropout(drop),
        )

    def forward(self, x):
        # pre-norm Transformer block with residual connections
        x = x + self.drop_path(self.attn(self.norm1(x)))
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
class VisionTransformer(nn.Module):
    def __init__(self, img_size=128, patch_size=16, in_chans=3, num_classes=5, embed_dim=768, depth=12,
                 num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.):
        super(VisionTransformer, self).__init__()
        self.num_classes = num_classes
        self.patch_embed = PatchEmbedding(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.n_patches
        # +1 position for the class token that is prepended in forward()
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.drop = nn.Dropout(drop_rate)
        self.blocks = nn.ModuleList([
            Block(
                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=drop_path_rate)
            for i in range(depth)])
        self.norm = nn.LayerNorm(embed_dim)
        self.fc = nn.Linear(embed_dim, num_classes)
        trunc_normal_(self.pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)                        # (B, n_patches, E)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)          # (B, n_patches + 1, E)
        x = x + self.pos_embed
        x = self.drop(x)
        for block in self.blocks:
            x = block(x)
        x = self.norm(x)
        cls_tokens = x[:, 0]                           # classify from the class token
        x = self.fc(cls_tokens)
        return x
def trunc_normal_(tensor, mean=0., std=1.):
    # defer to PyTorch's built-in truncated normal initializer instead of the rough fmod approximation
    return nn.init.trunc_normal_(tensor, mean=mean, std=std)
# quick smoke test: a (64, 3, 128, 128) batch should map to (64, 5) logits
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = VisionTransformer(num_classes=5).to(device)
x = torch.randn(64, 3, 128, 128, device=device)
output = model(x)
print(output.shape)  # torch.Size([64, 5])
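Note that the VisionTransformer above never actually calls CBAMBlock. A minimal way to combine the two, assuming the intent is to refine the raw image features before tokenization, is to run a CBAM stage on an early conv feature map and feed the result into the unmodified ViT; the wrapper name CBAMViT and the stem width below are just illustrative choices:

class CBAMViT(nn.Module):
    # ConvBlock lifts the 3-channel image to a richer feature map, CBAM re-weights it
    # channel- and spatial-wise, and a 1x1 conv maps back to 3 channels so the ViT's
    # patch embedding can consume it unchanged.
    def __init__(self, num_classes=5, stem_channels=32):
        super(CBAMViT, self).__init__()
        self.stem = ConvBlock(3, stem_channels, kernel_size=3, stride=1)
        self.cbam = CBAMBlock(stem_channels, reduction=16)
        self.reduce = nn.Conv2d(stem_channels, 3, kernel_size=1)
        self.vit = VisionTransformer(img_size=128, patch_size=16, in_chans=3, num_classes=num_classes)

    def forward(self, x):
        x = self.stem(x)     # (B, 32, 128, 128)
        x = self.cbam(x)     # attention-refined features, same shape
        x = self.reduce(x)   # back to (B, 3, 128, 128) for the unmodified patch embedding
        return self.vit(x)

model = CBAMViT(num_classes=5).to(device)
print(model(torch.randn(64, 3, 128, 128, device=device)).shape)  # torch.Size([64, 5])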