Write a ViT five-class classification network in PyTorch for input (64, 3, 128, 128), with CBAM, BatchNorm, and dropout layers; it should be efficient, perform well, and run correctly.
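Below is a runnable implementation: a strided-convolution stem serves as the patch embedding, five hybrid blocks combine Conv + BatchNorm + CBAM with multi-head self-attention and a dropout-regularized MLP, and a global-average-pooling head with dropout produces the five-class logits.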
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from torchsummary import summary
# CBAM module: channel attention followed by spatial attention
class CBAM(nn.Module):
    def __init__(self, channels, reduction=16, spatial_kernel=7):
        super(CBAM, self).__init__()
        # Channel attention: a shared MLP applied to avg- and max-pooled features
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1, bias=False)
        # Spatial attention: a 7x7 conv over the channel-wise mean/max maps
        self.spatial = nn.Conv2d(2, 1, kernel_size=spatial_kernel,
                                 padding=spatial_kernel // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Channel attention: the (b, c, 1, 1) gate broadcasts over h and w
        avg_out = self.fc2(self.relu(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu(self.fc1(self.max_pool(x))))
        x = x * self.sigmoid(avg_out + max_out)
        # Spatial attention: the (b, 1, h, w) gate broadcasts over channels
        avg_map = torch.mean(x, dim=1, keepdim=True)
        max_map, _ = torch.max(x, dim=1, keepdim=True)
        x = x * self.sigmoid(self.spatial(torch.cat([avg_map, max_map], dim=1)))
        return x
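# Quick shape check for CBAM (illustrative only, not executed as part of the model):
#   cbam = CBAM(channels=256)
#   feat = torch.randn(2, 256, 8, 8)
#   cbam(feat).shape == feat.shape  # both attention gates preserve the input shape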
# ViT block: Conv + BN + CBAM front-end, then pre-norm self-attention and MLP
class ViTBlock(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride, num_heads, dropout):
        super(ViTBlock, self).__init__()
        # Padding keeps h/w unchanged so the residual additions line up
        # (the block assumes stride=1 and in_channels == out_channels)
        padding = kernel_size // 2
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, stride, padding)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.cbam = CBAM(out_channels)
        self.norm1 = nn.LayerNorm(out_channels)
        self.norm2 = nn.LayerNorm(out_channels)
        self.mlp = nn.Sequential(
            nn.Linear(out_channels, 4 * out_channels),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(4 * out_channels, out_channels),
            nn.Dropout(dropout)
        )
        # batch_first=True lets us feed (b, seq, c) tensors directly
        self.attention = nn.MultiheadAttention(out_channels, num_heads,
                                               dropout=dropout, batch_first=True)

    def forward(self, x):
        # Convolutional branch with CBAM, plus a residual connection
        residual = x
        x = F.gelu(self.bn1(self.conv1(x)))
        x = self.bn2(self.conv2(x))
        x = self.cbam(x)
        x = residual + x
        # Flatten the feature map into a token sequence: (b, c, h, w) -> (b, h*w, c)
        b, c, h, w = x.shape
        x = rearrange(x, 'b c h w -> b (h w) c')
        # Pre-norm attention and MLP, each with its own residual
        x_norm = self.norm1(x)
        attn_output, _ = self.attention(x_norm, x_norm, x_norm)
        x = x + attn_output
        x = x + self.mlp(self.norm2(x))
        # Restore the (b, c, h, w) layout expected by the next block
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
        return x
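# Note: with the 16x16 patch stem below, each block's attention runs over only
# 8 * 8 = 64 tokens on a 128x128 input, which keeps the self-attention cheap.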
# ViT model
class ViT(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, stride,
                 num_heads, num_classes, dropout, patch_size=16):
        super(ViT, self).__init__()
        # Patch-embedding stem: a strided conv turns the 128x128 input into an
        # 8x8 token grid, keeping the attention sequence length small
        self.conv1 = nn.Conv2d(in_channels, out_channels,
                               kernel_size=patch_size, stride=patch_size)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.blocks = nn.Sequential(*[
            ViTBlock(out_channels, out_channels, kernel_size, stride, num_heads, dropout)
            for _ in range(5)
        ])
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.dropout = nn.Dropout(dropout)
        self.fc = nn.Linear(out_channels, num_classes)

    def forward(self, x):
        x = F.gelu(self.bn1(self.conv1(x)))
        x = self.blocks(x)
        x = self.pool(x).flatten(start_dim=1)  # global average pool -> (b, c)
        x = self.dropout(x)
        return self.fc(x)  # raw logits for the 5 classes
# Quick test
if __name__ == '__main__':
    model = ViT(in_channels=3, out_channels=256, kernel_size=3, stride=1,
                num_heads=8, num_classes=5, dropout=0.2, patch_size=16)
    # Forward pass with the requested (64, 3, 128, 128) input
    x = torch.randn(64, 3, 128, 128)
    print(model(x).shape)  # expected: torch.Size([64, 5])
    summary(model, (3, 128, 128), device='cpu')
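For completeness, here is a minimal training-step sketch for the five-class setup. The random images and labels are stand-ins for a real DataLoader batch, and the AdamW hyperparameters are illustrative defaults, not part of the original answer:

# Minimal training step (CrossEntropyLoss expects the raw logits from ViT.forward)
model = ViT(in_channels=3, out_channels=256, kernel_size=3, stride=1,
            num_heads=8, num_classes=5, dropout=0.2, patch_size=16)
optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=0.05)
criterion = nn.CrossEntropyLoss()
model.train()
images = torch.randn(64, 3, 128, 128)   # stand-in for a real image batch
labels = torch.randint(0, 5, (64,))     # stand-in for real class labels
loss = criterion(model(images), labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(loss.item())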