Write a five-class ViT classification network in PyTorch, with a CBAM module and residual connections, for an input of shape (64, 3, 128, 128)
Below is a simple implementation that includes the basic ViT structure, a CBAM module, and residual connections:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class PatchEmbedding(nn.Module):
    def __init__(self, in_channels=3, patch_size=16, emb_size=768):
        super().__init__()
        # Split the image into non-overlapping patches and project each one to emb_size
        self.proj = nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)       # (B, E, H/P, W/P)
        x = x.flatten(2)       # (B, E, N)
        x = x.transpose(1, 2)  # (B, N, E)
        return x

class MultiHeadAttention(nn.Module):
    def __init__(self, emb_size, num_heads):
        super().__init__()
        self.emb_size = emb_size
        self.num_heads = num_heads
        self.head_dim = emb_size // num_heads
        self.qkv = nn.Linear(emb_size, emb_size * 3, bias=False)
        self.fc = nn.Linear(emb_size, emb_size)

    def forward(self, x):
        qkv = self.qkv(x)  # (B, N, 3 * E)
        qkv = qkv.reshape(-1, x.shape[1], 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)   # (3, B, H, N, D)
        q, k, v = qkv[0], qkv[1], qkv[2]   # each (B, H, N, D)
        attn_weights = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)  # (B, H, N, N)
        attn_weights = F.softmax(attn_weights, dim=-1)
        attn_output = attn_weights @ v               # (B, H, N, D)
        attn_output = attn_output.transpose(1, 2)    # (B, N, H, D)
        attn_output = attn_output.flatten(2)         # (B, N, E)
        attn_output = self.fc(attn_output)           # (B, N, E)
        return attn_output

class MLP(nn.Module):
    def __init__(self, emb_size, mlp_size):
        super().__init__()
        self.fc1 = nn.Linear(emb_size, mlp_size)
        self.fc2 = nn.Linear(mlp_size, emb_size)
        self.act = nn.GELU()

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.fc2(x)
        return x

class ResidualBlock(nn.Module):
    # Pre-norm Transformer encoder block with residual connections
    def __init__(self, emb_size, num_heads, mlp_size):
        super().__init__()
        self.norm1 = nn.LayerNorm(emb_size)
        self.attn = MultiHeadAttention(emb_size, num_heads)
        self.norm2 = nn.LayerNorm(emb_size)
        self.mlp = MLP(emb_size, mlp_size)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))  # residual connection around attention
        x = x + self.mlp(self.norm2(x))   # residual connection around the MLP
        return x

class CBAM(nn.Module):
    # Channel attention in the CBAM style, adapted to token sequences of shape (B, N, E):
    # average- and max-pool over the token dimension, pass both descriptors through a
    # shared bottleneck MLP, and re-weight the embedding channels with the result.
    def __init__(self, emb_size, reduction=2):
        super().__init__()
        self.avg_pool = nn.AdaptiveAvgPool1d(1)
        self.max_pool = nn.AdaptiveMaxPool1d(1)
        self.fc1 = nn.Linear(emb_size, emb_size // reduction)
        self.act = nn.ReLU(inplace=True)
        self.fc2 = nn.Linear(emb_size // reduction, emb_size)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # x: (B, N, E); pool over the token dimension N to get per-channel descriptors
        avg = self.avg_pool(x.transpose(-1, -2)).transpose(-1, -2)  # (B, 1, E)
        mx = self.max_pool(x.transpose(-1, -2)).transpose(-1, -2)   # (B, 1, E)
        attn = self.fc2(self.act(self.fc1(avg))) + self.fc2(self.act(self.fc1(mx)))
        attn = self.sigmoid(attn)                                   # (B, 1, E)
        return x * attn  # broadcast the channel weights over all tokens

class ViT_CBAM(nn.Module):
    def __init__(self, in_channels=3, patch_size=16, emb_size=768, num_heads=12, num_classes=5):
        super().__init__()
        self.patch_embedding = PatchEmbedding(in_channels, patch_size, emb_size)
        self.num_patches = (128 // patch_size) ** 2  # 64 patches for a 128x128 input
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
        self.pos_embedding = nn.Parameter(torch.randn(1, self.num_patches + 1, emb_size))
        self.dropout = nn.Dropout(0.5)
        self.cbam = CBAM(emb_size)
        self.blocks = nn.Sequential(*[
            ResidualBlock(emb_size, num_heads, emb_size * 4)
            for _ in range(6)
        ])
        self.head = nn.Linear(emb_size, num_classes)

    def forward(self, x):
        x = self.patch_embedding(x)                            # (B, N, E)
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat([cls_token, x], dim=1)                   # (B, N + 1, E)
        x = x + self.pos_embedding
        x = self.dropout(x)
        x = self.blocks(x)
        x = self.cbam(x)
        x = x.mean(dim=1)                                      # average over all tokens
        x = self.head(x)                                       # (B, num_classes)
        return x
```
The above is a basic implementation of a five-class ViT network with a CBAM module and residual connections. The parameters in the model can be adjusted to meet different requirements.
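As a quick sanity check (a minimal sketch, assuming the class definitions above are in scope), you can feed a random batch of the stated input shape through the model and confirm that it produces one logit per class:

```python
import torch

model = ViT_CBAM(in_channels=3, patch_size=16, emb_size=768, num_heads=12, num_classes=5)
dummy = torch.randn(64, 3, 128, 128)  # (batch, channels, height, width)
logits = model(dummy)
print(logits.shape)  # expected: torch.Size([64, 5])
```

For training, the raw logits can be passed directly to `nn.CrossEntropyLoss`, which applies the softmax internally. Note that the positional embedding is sized from `patch_size` for a fixed 128x128 input, so changing the input resolution also requires adjusting `num_patches` accordingly.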