CoordinateAttention加入resnet18的代码
时间: 2023-07-02 09:04:32 浏览: 190
以下是将 Coordinate Attention 应用于 ResNet18 的代码示例:
```python
import torch
import torch.nn as nn
from torch.nn import Module
from torch.nn.parameter import Parameter
from torch.nn.functional import conv2d
class CoordinateAttention(Module):
    """Spatial gate whose strength depends on normalized pixel coordinates.

    Each spatial location ``(i, j)`` of the input is multiplied by
    ``1 + gamma * exp(-(y_i ** 2 + x_j ** 2))``, where ``y`` and ``x`` are the
    row and column coordinates linearly spaced over ``[-1, 1]``. ``gamma`` is
    a learnable scalar initialised to zero, so a freshly built module returns
    its input unchanged.

    NOTE(review): the module has no pooling or convolution layers; it is a
    coordinate-based gate and presumably a simplification of the published
    "Coordinate Attention" block -- confirm this is what is intended.

    Args:
        in_channels: Channel count of the expected input. It is stored on the
            instance but not used by ``forward``.
    """

    def __init__(self, in_channels):
        super(CoordinateAttention, self).__init__()
        self.in_channels = in_channels
        # Zero init makes the gate equal to 1 everywhere at the start.
        self.gamma = Parameter(torch.zeros(1))

    def forward(self, x):
        """Scale a 4-D ``(N, C, H, W)`` tensor by the coordinate gate.

        The result has the same shape as ``x``.
        """
        _, _, height, width = x.size()
        rows = torch.linspace(-1, 1, height, device=x.device, dtype=torch.float32)
        cols = torch.linspace(-1, 1, width, device=x.device, dtype=torch.float32)
        # Squared distance from the map centre, laid out as (1, 1, H, W) so it
        # broadcasts over batch and channel instead of colliding with them.
        sq_dist = (
            torch.square(rows).view(1, 1, height, 1)
            + torch.square(cols).view(1, 1, 1, width)
        )
        weight = 1 + self.gamma * torch.exp(-sq_dist)
        return x * weight
class ResNet18_CA(nn.Module):
    """ResNet-style image classifier with CoordinateAttention gates.

    Layout: a 7x7 stride-2 stem (conv, BN, ReLU, attention, max-pool), four
    stages of two 3x3 conv-BN-ReLU units each ending in an attention gate
    (64, 128, 256 and 512 channels), global average pooling and a linear
    classifier.

    Args:
        num_classes: Number of output logits produced by the final linear
            layer.
    """

    def __init__(self, num_classes=1000):
        super(ResNet18_CA, self).__init__()
        self.num_classes = num_classes

        # Stem: expects a 3-channel input.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.ca1 = CoordinateAttention(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        # Only the first stage keeps the resolution; the others halve it.
        self.layer1 = self._make_stage(64, 64, stride=1)
        self.layer2 = self._make_stage(64, 128, stride=2)
        self.layer3 = self._make_stage(128, 256, stride=2)
        self.layer4 = self._make_stage(256, 512, stride=2)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    @staticmethod
    def _make_stage(in_channels, out_channels, stride):
        """Build one stage: two 3x3 conv-BN-ReLU units, then attention.

        Only the first convolution applies ``stride`` and changes the channel
        count. The module order is fixed so that sub-module indices (and so
        state-dict keys) stay stable; ``forward`` relies on index 0 being the
        first convolution.
        """
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride,
                      padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1,
                      padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            CoordinateAttention(out_channels),
        )

    def forward(self, x):
        """Return class logits of shape ``(N, num_classes)`` for a 4-D input."""
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.ca1(x)
        x = self.maxpool(x)

        # Stage 1 preserves shape, so the input itself is the shortcut.
        x = self.layer1(x) + x

        # Stages 2-4 change channels and resolution, so the shortcut must be
        # reshaped to match.
        # NOTE(review): the shortcut reuses the weights of the stage's own
        # first convolution instead of a separate projection layer -- confirm
        # this weight sharing is intended.
        for stage in (self.layer2, self.layer3, self.layer4):
            first_conv = stage[0]
            x = stage(x) + conv2d(
                x,
                first_conv.weight,
                bias=None,
                stride=first_conv.stride,
                padding=first_conv.padding,
            )

        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        x = self.fc(x)
        return x
```
其中,`CoordinateAttention` 是实现 Coordinate Attention 模块的类,`ResNet18_CA` 则是在 ResNet18 基础上修改、加入了 Coordinate Attention 模块的模型。在 `ResNet18_CA` 中,利用 `nn.Sequential` 将每个 stage 中的卷积层、BN 层、ReLU 激活以及 Coordinate Attention 模块依次组合在一起。在 `forward` 函数中,每个 stage 的输出都要加上一个残差连接(Residual Connection):`layer1` 的输出直接与其输入相加;`layer2`~`layer4` 会将通道数加倍并进行步长为 2 的下采样,因此其输入需先经过一次步长为 2 的卷积,使形状与主分支一致后再相加。最后,将特征依次送入全局平均池化(`nn.AdaptiveAvgPool2d`)和全连接层,得到分类预测结果。
阅读全文