计算机视觉中的注意力机制:从CNN到ViT的发展历程
发布时间: 2024-08-20 23:49:46 阅读量: 40 订阅数: 47
白色大气风格的建筑商业网站模板下载.rar
![计算机视觉中的注意力机制:从CNN到ViT的发展历程](https://www.sopitas.com/wp-content/uploads/2022/12/different-dimension-me-app-ia-convertir-foto-anime.png)
# 1. 计算机视觉中的注意力机制概述
计算机视觉中的注意力机制是一种受人类视觉系统启发的技术,它允许模型专注于图像或视频中重要的区域或特征。注意力机制通过分配权重来实现,这些权重表示每个区域或特征对最终决策的重要性。
注意力机制在计算机视觉中得到了广泛的应用,包括图像分类、目标检测和图像分割。它使模型能够更有效地从数据中学习,并提高其在复杂场景中的性能。注意力机制的类型包括基于CNN的注意力机制和基于ViT的注意力机制,它们分别利用卷积神经网络(CNN)和视觉Transformer(ViT)来计算注意力权重。
# 2. 基于 CNN 的注意力机制
在卷积神经网络(CNN)中,注意力机制通过赋予特征图中不同区域和通道不同的权重,来增强模型对重要特征的关注。基于 CNN 的注意力机制主要分为空间注意力机制和通道注意力机制。
### 2.1 空间注意力机制
空间注意力机制关注特征图中的空间维度,通过突出重要区域来增强模型对特定对象的识别。
#### 2.1.1 Squeeze-and-Excitation (SE) 模块
SE 模块是一种简单的空间注意力机制,它通过对特征图的通道维度进行全局池化和重新加权来增强特征图中的重要通道。
```python
import torch
import torch.nn as nn
class SqueezeAndExcitation(nn.Module):
def __init__(self, channels):
super().__init__()
self.fc1 = nn.Linear(channels, channels // 16)
self.fc2 = nn.Linear(channels // 16, channels)
def forward(self, x):
# 全局池化
avg_pool = torch.mean(x, dim=[2, 3])
# 压缩
z = self.fc1(avg_pool)
# 激励
z = F.relu(z)
# 扩张
z = self.fc2(z)
# 重新加权
x = x * z.unsqueeze(2).unsqueeze(3)
return x
```
**逻辑分析:**
* `avg_pool` 对特征图进行全局池化,得到一个通道维度的向量。
* `fc1` 和 `fc2` 组成两个全连接层,对通道维度的向量进行压缩和扩张。
* `F.relu` 函数对压缩后的向量进行 ReLU 激活。
* `unsqueeze(2).unsqueeze(3)` 操作将通道维度的向量扩展为与特征图相同的形状。
* `x * z.unsqueeze(2).unsqueeze(3)` 对特征图进行重新加权,赋予重要通道更大的权重。
#### 2.1.2 Channel Attention Module (CAM)
CAM 是一种更复杂的注意力机制,它通过对特征图的通道维度进行自注意力计算来增强特征图中的重要通道。
```python
import torch
import torch.nn as nn
class ChannelAttentionModule(nn.Module):
def __init__(self, channels):
super().__init__()
self.query = nn.Linear(channels, channels // 8)
self.key = nn.Linear(channels, channels // 8)
self.value = nn.Linear(channels, channels)
def forward(self, x):
# 查询、键、值计算
q = self.query(x).transpose(2, 3)
k = self.key(x)
v = self.value(x)
# 自注意力计算
a = torch.matmul(q, k) / (channels // 8) ** 0.5
a = F.softmax(a, dim=3)
# 加权求和
x = torch.matmul(a, v)
return x
```
**逻辑分析:**
* `query`、`key` 和 `value` 组成三个全连接层,分别对特征图的通道维度进行查询、键和值计算。
* `torch.matmul(q, k) / (channels // 8) ** 0.5` 计算查询和键之间的点积,并进行归一化。
* `F.softmax(a, dim=3)` 对点积结果进行 softmax 操作,得到通道维度的权重。
* `torch.matmul(a, v)` 对特征图的通道维度进行加权求和,增强重要通道的权重。
### 2.2 通道注意力机制
通道注意力机制关注特征图中的通道维度,通过突出重要通道来增强模型对特定特征的识别。
#### 2.2.1 Spatial Transformer Networks (STN)
STN 是一种基于空间变换的注意力机制,它通过对特征图进行仿射变换来对齐重要区域。
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
class SpatialTransformer(nn.Module):
def __init__(self, channels):
super().__init__()
self.fc1 = nn.Linear(channels, channels * 2 * 3)
def forward(self, x):
# 仿射变换参数计算
theta = self.fc1(x).view(-1, 2, 3)
# 空间变换
grid = F.affine_grid(theta, x.size())
x = F.grid_sample(x, grid)
return x
```
**逻辑分析:**
* `fc1` 组成一个全连接层,对特征图的通道维度进行仿射变换参数计算。
* `view(-1, 2, 3)` 将仿射变换参数重塑为 [batch_size, 2, 3] 的形状。
* `F.affi
0
0