PyTorch implementation: adding a CBAM attention module alongside blocks in a model
In PyTorch, a custom model is built by subclassing `nn.Module`. To combine a CBAM attention module with a block in a model, first define a CBAM attention module class and a block class, then use the two side by side in the main model.
Here is an example implementation:
```python
import torch
import torch.nn as nn
class CBAM(nn.Module):
    def __init__(self, channels, reduction=16):
        super(CBAM, self).__init__()
        # Channel attention: a shared bottleneck MLP (implemented as 1x1 convs)
        # applied to globally average-pooled and max-pooled descriptors
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc1 = nn.Conv2d(channels, channels // reduction, 1, bias=False)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(channels // reduction, channels, 1, bias=False)
        self.sigmoid_channel = nn.Sigmoid()
        # Spatial attention: a 7x7 conv over the 2-channel map built from
        # channel-wise average and max
        self.conv_after_concat = nn.Conv2d(2, 1, kernel_size=7, stride=1, padding=3)
        self.sigmoid_spatial = nn.Sigmoid()

    def forward(self, x):
        # Channel attention
        avg_out = self.fc2(self.relu(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu(self.fc1(self.max_pool(x))))
        channel_out = self.sigmoid_channel(avg_out + max_out) * x
        # Spatial attention: pool across the channel dimension (keeping the
        # spatial dims) so the concatenated map has exactly 2 channels, as the
        # 7x7 conv expects
        avg_map = torch.mean(channel_out, dim=1, keepdim=True)
        max_map, _ = torch.max(channel_out, dim=1, keepdim=True)
        spatial_out = self.conv_after_concat(torch.cat([avg_map, max_map], dim=1))
        spatial_out = self.sigmoid_spatial(spatial_out)
        return channel_out * spatial_out
class Block(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1):
        super(Block, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size, stride=1, padding=padding, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.cbam = CBAM(out_channels)
        # 1x1 projection shortcut so the residual addition is valid when the
        # stride or the channel count changes (e.g. the first block of a stage)
        self.downsample = None
        if stride != 1 or in_channels != out_channels:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        residual = x if self.downsample is None else self.downsample(x)
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out = self.cbam(out)
        out += residual
        out = self.relu(out)
        return out
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        # Stem: 7x7 conv plus max pooling, as in a standard ResNet
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        # Four stages of residual blocks (3-4-6-3 layout, mirroring ResNet-34)
        self.layer1 = nn.Sequential(
            Block(in_channels=64, out_channels=64),
            Block(in_channels=64, out_channels=64),
            Block(in_channels=64, out_channels=64)
        )
        self.layer2 = nn.Sequential(
            Block(in_channels=64, out_channels=128, stride=2),
            Block(in_channels=128, out_channels=128),
            Block(in_channels=128, out_channels=128),
            Block(in_channels=128, out_channels=128)
        )
        self.layer3 = nn.Sequential(
            Block(in_channels=128, out_channels=256, stride=2),
            Block(in_channels=256, out_channels=256),
            Block(in_channels=256, out_channels=256),
            Block(in_channels=256, out_channels=256),
            Block(in_channels=256, out_channels=256),
            Block(in_channels=256, out_channels=256)
        )
        self.layer4 = nn.Sequential(
            Block(in_channels=256, out_channels=512, stride=2),
            Block(in_channels=512, out_channels=512),
            Block(in_channels=512, out_channels=512)
        )
        # Classification head
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x
```
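With the classes above in scope, a quick smoke test confirms the network runs end to end; the batch size of 2 and the 224×224 input below are arbitrary assumptions for illustration:
```python
# Hypothetical smoke test: batch size and input resolution are arbitrary choices.
model = MyModel()
dummy = torch.randn(2, 3, 224, 224)  # [batch, RGB channels, height, width]
logits = model(dummy)
print(logits.shape)  # torch.Size([2, 10]) -- one score per class
```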
To recap: the `CBAM` class implements the CBAM attention module (channel attention followed by spatial attention), and the `Block` class implements a basic residual block consisting of convolution, batch normalization, ReLU activation, and a CBAM module, with a 1×1 projection shortcut so the residual addition stays valid when the stride or channel count changes. `MyModel` stacks several of these blocks into a complete residual network and finishes with global average pooling and a fully connected layer for classification.
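Because CBAM only reweights features, its output always has the same shape as its input, so the module can also be attached after any existing block without further changes. A minimal sketch, assuming a hypothetical 64-channel, 32×32 feature map:
```python
# CBAM is shape-preserving: attention reweights the feature map, never reshapes it.
cbam = CBAM(channels=64)           # assumed channel count for illustration
feat = torch.randn(2, 64, 32, 32)  # assumed feature-map size for illustration
out = cbam(feat)
assert out.shape == feat.shape
```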