CBAM attention mechanism code
CBAM (Convolutional Block Attention Module) is an attention mechanism for improving the performance of convolutional neural networks (CNNs). Below is example code implementing a CBAM attention module in PyTorch:
``` python
import torch
import torch.nn as nn


class ChannelAttention(nn.Module):
    """Channel attention: produces a per-channel importance weight."""
    def __init__(self, in_planes, ratio=16):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        # Shared two-layer bottleneck MLP, implemented as 1x1 convolutions
        self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
        self.relu1 = nn.ReLU()
        self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Run both pooled descriptors through the shared MLP
        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
        # Sum the two branches and squash to (0, 1) attention weights
        out = self.sigmoid(avg_out + max_out)
        return out


class SpatialAttention(nn.Module):
    """Spatial attention: produces a per-pixel importance weight."""
    def __init__(self, kernel_size=7):
        super(SpatialAttention, self).__init__()
        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
        padding = 3 if kernel_size == 7 else 1
        # 2 input channels: channel-wise average map and channel-wise max map
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        out = self.sigmoid(x)
        return out


class CBAM(nn.Module):
    """CBAM: channel attention followed by spatial attention."""
    def __init__(self, in_planes, ratio=16, kernel_size=7):
        super(CBAM, self).__init__()
        self.ca = ChannelAttention(in_planes, ratio)
        self.sa = SpatialAttention(kernel_size)

    def forward(self, x):
        out = self.ca(x) * x      # reweight channels
        out = self.sa(out) * out  # reweight spatial locations
        return out
```
In the code above, we first define a `ChannelAttention` class and a `SpatialAttention` class. `ChannelAttention` implements channel attention: it squeezes the input feature map with global average pooling and global max pooling, passes both descriptors through a shared two-layer MLP, then sums the results and applies a sigmoid to produce a per-channel importance weight. `SpatialAttention` implements spatial attention: it average-pools and max-pools the input along the channel dimension, concatenates the two resulting maps, and applies a convolution followed by a sigmoid to produce a per-pixel importance weight.
Finally, the `CBAM` class chains the two modules: it multiplies the input by the channel attention weights first, then multiplies that result by the spatial attention weights.
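Because both attention maps are broadcast-multiplied against the input, a CBAM block preserves the shape of whatever feature map it receives. A minimal sanity check, assuming the classes above are in scope:
``` python
import torch

# Dummy feature map: batch of 4, 64 channels, 32x32 spatial resolution
x = torch.randn(4, 64, 32, 32)
cbam = CBAM(in_planes=64)

out = cbam(x)
print(out.shape)  # torch.Size([4, 64, 32, 32]) -- same shape as the input
```
This shape-preserving property is what makes CBAM a drop-in module: it can be inserted after any convolutional layer without changing the rest of the network.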
Example of using the CBAM module in a network:
``` python
import torch
import torch.nn as nn
import torch.nn.functional as F  # needed for F.relu and F.max_pool2d

from cbam import CBAM  # assumes the CBAM code above is saved as cbam.py


class Net(nn.Module):
    """Simple CNN with a CBAM block after every convolution.

    Expects 3x128x128 inputs: four 2x2 max-pools reduce 128 -> 8.
    """
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.cbam1 = CBAM(64)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.cbam2 = CBAM(64)
        self.conv3 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
        self.cbam3 = CBAM(128)
        self.conv4 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
        self.cbam4 = CBAM(128)
        self.fc = nn.Linear(128 * 8 * 8, 10)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = self.cbam1(x)
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv2(x))
        x = self.cbam2(x)
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv3(x))
        x = self.cbam3(x)
        x = F.max_pool2d(x, 2)
        x = F.relu(self.conv4(x))
        x = self.cbam4(x)
        x = F.max_pool2d(x, 2)
        x = x.view(-1, 128 * 8 * 8)  # flatten before the classifier head
        x = self.fc(x)
        return x
```
In the code above, we define a simple CNN and insert a CBAM block after each convolutional layer. The attention modules help the network learn better feature representations, which can improve the model's accuracy.
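To confirm the wiring, a quick forward pass with a random batch works as a smoke test (the 128x128 input size is an assumption implied by the `128 * 8 * 8` flatten above):
``` python
import torch

model = Net()
# 2 random RGB images at the 128x128 resolution the flatten layer expects
dummy = torch.randn(2, 3, 128, 128)
logits = model(dummy)
print(logits.shape)  # torch.Size([2, 10])
```
CBAM adds only a small number of parameters per block (a bottleneck MLP plus a single 7x7 convolution over two channels), so inserting it is cheap relative to the backbone.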