Reproducing ConvNeXt
### Implementing the ConvNeXt Model
#### PyTorch Implementation
To implement ConvNeXt in PyTorch, start with the basic building block. ConvNeXt uses a modernized residual block together with LayerScale (the learnable `gamma` below) to improve expressiveness. The snippet also defines the custom `LayerNorm` (supporting channels-first tensors) and imports the stochastic-depth `DropPath` helper from `timm`, both of which the official implementation relies on.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import trunc_normal_, DropPath  # pip install timm

class LayerNorm(nn.Module):
    r""" LayerNorm supporting two data formats: channels_last (N, H, W, C, the default)
    or channels_first (N, C, H, W), as in the official ConvNeXt implementation.
    """
    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in ["channels_last", "channels_first"]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)
    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        elif self.data_format == "channels_first":
            # Normalize over the channel dimension by hand for (N, C, H, W) tensors
            u = x.mean(1, keepdim=True)
            s = (x - u).pow(2).mean(1, keepdim=True)
            x = (x - u) / torch.sqrt(s + self.eps)
            x = self.weight[:, None, None] * x + self.bias[:, None, None]
            return x

class ConvNextBlock(nn.Module):
r""" ConvNeXt Block. There are two equivalent implementations:
(1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
(2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
Implementation follows the official implementation which uses option (2).
Args:
dim (int): Number of input channels.
drop_path (float): Stochastic depth rate. Default: 0.0
layer_scale_init_value (float): Init value for Layer Scale. Default: 1e-6.
"""
def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
super().__init__()
self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim) # depthwise conv
self.norm = LayerNorm(dim, eps=1e-6)
self.pwconv1 = nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
self.act = nn.GELU()
self.pwconv2 = nn.Linear(4 * dim, dim)
if layer_scale_init_value > 0:
self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)), requires_grad=True)
else:
self.gamma = None
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
def forward(self, x):
shortcut = x
x = self.dwconv(x)
x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
x = self.norm(x)
x = self.pwconv1(x)
x = self.act(x)
x = self.pwconv2(x)
if self.gamma is not None:
x = self.gamma * x
x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
x = shortcut + self.drop_path(x)
return x
```
This code builds a standard ConvNeXt block, combining a depthwise convolution, layer normalization, and position-wise fully connected (1×1) layers[^5].
For the complete model, several of these blocks are stacked and interleaved with downsampling stages:
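As a quick sanity check, a block maps an `(N, C, H, W)` tensor to one of the same shape; the dimensions below are illustrative, matching the first stage of ConvNeXt-Tiny:
```python
# Minimal smoke test for ConvNextBlock (shapes are illustrative)
block = ConvNextBlock(dim=96, drop_path=0.1)
x = torch.randn(2, 96, 56, 56)   # batch of 2, 96 channels, 56x56 feature map
out = block(x)
print(out.shape)  # torch.Size([2, 96, 56, 56]) -- residual blocks preserve shape
```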
```python
class ConvNeXt(nn.Module):
def __init__(self, in_chans=3, num_classes=1000,
depths=[3, 3, 9, 3], dims=[96, 192, 384, 768],
drop_path_rate=0.,
head_init_scale=1.,
):
super().__init__()
self.downsample_layers = nn.ModuleList() # stem and 3 intermediate downsampling conv layers
stem = nn.Sequential(
nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
LayerNorm(dims[0], eps=1e-6, data_format="channels_first")
)
self.downsample_layers.append(stem)
for i in range(3):
downsample_layer = nn.Sequential(
LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
nn.Conv2d(dims[i], dims[i+1], kernel_size=2, stride=2),
)
self.downsample_layers.append(downsample_layer)
self.stages = nn.ModuleList() # 4 feature resolution stages, each consisting of multiple residual blocks
dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
cur = 0
for i in range(4):
stage = nn.Sequential(
*[ConvNextBlock(dim=dims[i], drop_path=dp_rates[cur + j]) for j in range(depths[i])]
)
self.stages.append(stage)
cur += depths[i]
self.norm = nn.LayerNorm(dims[-1], eps=1e-6) # final norm layer
self.head = nn.Linear(dims[-1], num_classes)
self.apply(self._init_weights)
self.head.weight.data.mul_(head_init_scale)
self.head.bias.data.mul_(head_init_scale)
def _init_weights(self, m):
if isinstance(m, (nn.Conv2d, nn.Linear)):
trunc_normal_(m.weight, std=.02)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward_features(self, x):
for i in range(4):
x = self.downsample_layers[i](x)
x = self.stages[i](x)
return self.norm(x.mean([-2, -1])) # global average pooling, (N, C, H, W) -> (N, C)
def forward(self, x):
x = self.forward_features(x)
x = self.head(x)
return x
```
The code above implements the full ConvNeXt architecture: a four-stage feature extractor operating at decreasing resolutions, followed by global average pooling and a linear classification head that produces the final prediction.
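For reference, the default `depths=[3, 3, 9, 3]` and `dims=[96, 192, 384, 768]` correspond to the ConvNeXt-Tiny configuration. A minimal sketch of instantiating and smoke-testing the model (the 224×224 input size is an assumption, matching standard ImageNet training):
```python
# Build ConvNeXt-Tiny with the default arguments and run a dummy forward pass
model = ConvNeXt(depths=[3, 3, 9, 3], dims=[96, 192, 384, 768], drop_path_rate=0.1)
x = torch.randn(1, 3, 224, 224)
logits = model(x)
print(logits.shape)                                      # torch.Size([1, 1000])
print(sum(p.numel() for p in model.parameters()) / 1e6)  # roughly 28M parameters
```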
#### TensorFlow/Keras Implementation
TensorFlow's Keras API offers a higher-level way to assemble a network quickly. The example below sketches a simplified ConvNeXt-like structure; note that it substitutes BatchNormalization and ReLU for the LayerNorm and GELU used in the paper, so it is a structural analogue rather than a faithful reproduction:
```python
from tensorflow.keras.layers import Input, Conv2D, DepthwiseConv2D, BatchNormalization, ReLU, Add, GlobalAveragePooling2D, Dense
from tensorflow.keras.models import Model
def convnext_block(inputs, filters, name=None):
"""A single block within a ConvNeXt model."""
x = inputs
res = x
    # Depth-wise convolution: one 7x7 filter per input channel
    x = DepthwiseConv2D(kernel_size=(7, 7), strides=(1, 1), padding='same', use_bias=False)(x)
x = BatchNormalization()(x)
# Point-wise expansion & contraction
expanded_filters = int(filters * 4)
x = Conv2D(expanded_filters, kernel_size=(1, 1))(x)
x = ReLU()(x)
x = Conv2D(filters, kernel_size=(1, 1))(x)
# Residual connection
out = Add(name=name)([res, x])
return out
input_shape = (224, 224, 3)
num_classes = 1000
inputs = Input(shape=input_shape)
stem = Conv2D(96, kernel_size=(4, 4), strides=(4, 4))(inputs)
# Define your own number of blocks per stage here...
stage_1 = convnext_block(stem, 96, name='stage1_block1')  # block name is a placeholder

# Classification head (only one stage shown; stack more blocks/stages as needed)
x = GlobalAveragePooling2D()(stage_1)
outputs = Dense(num_classes)(x)
model = Model(inputs, outputs)
```
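Once the model is assembled, the usual Keras workflow applies. A brief usage sketch (the optimizer and loss here are placeholder choices, not from the original text; `from_logits=True` matches the head above, which has no softmax):
```python
import tensorflow as tf

# Compile and inspect the model; hyperparameters are placeholders
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
model.summary()
```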