YOLOv7 Re-parameterized Convolution (RepConv)
Date: 2025-01-02 19:42:41
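### How RepConv Works in YOLOv7
RepConv is a model re-parameterization technique. During training it uses a multi-branch structure, which gives the network more representational capacity. For inference, it converts those branches into a single, mathematically equivalent convolution. The deployed model keeps the accuracy gained from multi-branch training but pays only the compute and memory cost of a plain convolution.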
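A single convolution can replace several branches because convolution is linear in its kernel. Applying two same-shaped kernels to the same input and adding the results gives the same output as applying their sum once. A minimal check in PyTorch, using random tensors with arbitrary shapes:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(1, 4, 8, 8)     # input: batch 1, 4 channels, 8x8
w1 = torch.randn(6, 4, 3, 3)    # two 3x3 kernels of the same shape
w2 = torch.randn(6, 4, 3, 3)

# Two parallel branches whose outputs are added ...
two_branches = F.conv2d(x, w1, padding=1) + F.conv2d(x, w2, padding=1)
# ... equal one convolution with the summed kernel
one_conv = F.conv2d(x, w1 + w2, padding=1)

print(torch.allclose(two_branches, one_conv, atol=1e-4))  # True
```

Re-parameterization builds on this identity, extended to BN layers and to kernels of different sizes.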
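#### Training
During training, RepConv feeds the input through several parallel branches and sums their outputs. The branches typically include, but are not limited to:
- the main branch: a standard 3×3 (or 1×1) convolution followed by BN;
- an optional identity branch, which passes the input through unchanged (in YOLOv7 it is implemented as a BN layer and exists only when the stride is 1 and the input and output channel counts match);
- auxiliary branches that may use filters of other sizes or types; in YOLOv7 this is a 1×1 convolution followed by BN[^2].

These extra branches let the network search for a solution in a richer parameter space, which tends to improve generalization and accuracy.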
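The common YOLOv7 layout has a 3×3 branch, a 1×1 branch and an identity branch, each followed by its own BN. Its training-time output can be written as below. Here $*$ denotes convolution, and the symbols $W_3$, $W_1$, $\mathrm{BN}_3$, $\mathrm{BN}_1$, $\mathrm{BN}_0$ are introduced only for illustration:

$$
y = \mathrm{BN}_{3}\left(W_{3} * x\right) + \mathrm{BN}_{1}\left(W_{1} * x\right) + \mathrm{BN}_{0}(x),
\qquad
\mathrm{BN}(z) = \gamma\,\frac{z-\mu}{\sqrt{\sigma^{2}+\epsilon}} + \beta
$$

In inference mode, each BN uses its fixed running statistics $\mu, \sigma^2$ and fixed affine parameters $\gamma, \beta$. Every term is then an affine function of $x$, so the three terms can be merged into one convolution with a bias.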
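#### Inference Optimization
For inference, all of these branches are merged into a single layer: one 3×3 convolution with a bias term. BN does not remain as a separate layer; its parameters are folded into that convolution's weight and bias. The conversion rests on two rules:
- **Weight merging**: for each branch, multiply every output channel of the kernel by that channel's BN scale γ/√(σ²+ε), where γ is the BN weight, σ² the running variance and ε a small stability constant. Bring each scaled kernel to 3×3: zero-pad a 1×1 kernel around its center, and write the identity branch as a 3×3 kernel with a 1 at the center for each channel's own input. Then add the kernels of all branches.
- **Bias adjustment**: each branch's BN also contributes a bias of β − μ·γ/√(σ²+ε), where β is the BN bias and μ the running mean. The biases of all branches are summed to form the new layer's bias vector.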
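Folding BN into the preceding convolution follows directly from BN's inference-time formula γ·(z − μ)/√(σ²+ε) + β. Below is a minimal sketch for one conv+BN pair. The function name `fuse_conv_bn` is hypothetical, and the conv is assumed to have no bias of its own, as is usual when BN follows it:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def fuse_conv_bn(conv: nn.Conv2d, bn: nn.BatchNorm2d):
    """Hypothetical helper: (weight, bias) of one conv equal to bn(conv(x)) in eval mode."""
    std = (bn.running_var + bn.eps).sqrt()
    t = bn.weight / std                            # per-output-channel BN scale
    weight = conv.weight * t.reshape(-1, 1, 1, 1)  # weight merging
    bias = bn.bias - bn.running_mean * t           # bias adjustment
    return weight.detach(), bias.detach()


torch.manual_seed(0)
conv = nn.Conv2d(4, 6, 3, padding=1, bias=False)
bn = nn.BatchNorm2d(6)
# Give BN non-trivial statistics and affine parameters
bn.running_mean.uniform_(-1, 1)
bn.running_var.uniform_(0.5, 2.0)
nn.init.uniform_(bn.weight, 0.5, 1.5)
nn.init.uniform_(bn.bias, -0.5, 0.5)
# Eval mode: BN uses running statistics, which is what makes fusion exact
conv.eval()
bn.eval()

w, b = fuse_conv_bn(conv, bn)
x = torch.randn(2, 4, 8, 8)
with torch.no_grad():
    reference = bn(conv(x))
fused = F.conv2d(x, w, b, padding=1)
print(torch.allclose(reference, fused, atol=1e-4))  # True
```

The equivalence holds only in eval mode. In training mode, BN normalizes with per-batch statistics, which a fixed kernel cannot reproduce.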
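After this rewrite, the multi-branch sub-network becomes one standard convolution. That convolution is simpler to execute, and typically faster and lighter on memory, than running several parallel branches and adding their outputs.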
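Before the branch kernels can be added, they must share the 3×3 shape. A 1×1 kernel is zero-padded to 3×3 so that its value sits at the center. The identity branch becomes a 3×3 kernel with a 1 at the center of channel pair (i, i) and 0 elsewhere. The sketch below checks that three branches collapse into one kernel. It makes simplifying assumptions: BN is omitted, there are no groups, stride is 1, `padding=1` is used for the 3×3 path, and the variable names are illustrative:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
C = 4
x = torch.randn(1, C, 8, 8)
w3 = torch.randn(C, C, 3, 3)   # 3x3 branch
w1 = torch.randn(C, C, 1, 1)   # 1x1 branch

# Multi-branch output: 3x3 conv + 1x1 conv + identity
y_branches = F.conv2d(x, w3, padding=1) + F.conv2d(x, w1) + x

# Zero-pad the 1x1 kernel to 3x3; its value ends up at the center
w1_as_3x3 = F.pad(w1, [1, 1, 1, 1])
# Identity as a 3x3 kernel: 1 at the center of channel pair (i, i)
w_id = torch.zeros(C, C, 3, 3)
w_id[torch.arange(C), torch.arange(C), 1, 1] = 1.0

y_single = F.conv2d(x, w3 + w1_as_3x3 + w_id, padding=1)
print(torch.allclose(y_branches, y_single, atol=1e-4))  # True
```

With BN present, each kernel would first be scaled and given a bias as described above; the padding and summing steps are unchanged.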
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class RepConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, groups=1, deploy=False):
        super().__init__()
        self.deploy = deploy
        self.in_channels = in_channels
        self.groups = groups
        if deploy:
            # Inference mode: a single conv layer (BN already folded into weight and bias)
            self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding,
                                  groups=groups, bias=True)
            return
        # Training mode with multiple branches; every conv is followed by BN,
        # so the convs themselves carry no bias
        self.branches = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding,
                          groups=groups, bias=False),
                nn.BatchNorm2d(out_channels))])
        if kernel_size != 1:
            # 1x1 branch; padding is reduced so its output aligns with the main branch
            # (assumes "same" padding, i.e. padding == kernel_size // 2)
            self.branches.append(nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 1, stride, padding - kernel_size // 2,
                          groups=groups, bias=False),
                nn.BatchNorm2d(out_channels)))
        if stride == 1 and in_channels == out_channels:
            # Identity branch, implemented as BN (as in YOLOv7) so it can be fused too
            self.branches.append(nn.BatchNorm2d(out_channels))

    def forward(self, x):
        if self.deploy:
            return self.conv(x)
        # Training mode: sum the outputs of all branches
        return sum(branch(x) for branch in self.branches)

    def _fuse_branch(self, branch, kernel_size):
        """Return (kernel, bias) of a kernel_size x kernel_size conv equivalent to one branch."""
        if isinstance(branch, nn.Sequential):
            conv, bn = branch[0], branch[1]
            # Zero-pad smaller kernels (1x1) to kernel_size, keeping the value at the center
            pad = (kernel_size - conv.kernel_size[0]) // 2
            kernel = F.pad(conv.weight, [pad, pad, pad, pad])
        else:
            # Identity branch: 1 at the center for each channel's own input (group-aware)
            bn = branch
            input_dim = self.in_channels // self.groups
            kernel = torch.zeros(self.in_channels, input_dim, kernel_size, kernel_size,
                                 device=bn.weight.device)
            for i in range(self.in_channels):
                kernel[i, i % input_dim, kernel_size // 2, kernel_size // 2] = 1.0
        # Fold BN (running statistics) into the kernel and bias
        std = (bn.running_var + bn.eps).sqrt()
        t = bn.weight / std
        return kernel * t.reshape(-1, 1, 1, 1), bn.bias - bn.running_mean * t

    @torch.no_grad()
    def switch_to_deploy(self):
        """Converts the model from training to deployment by fusing all branches into one conv."""
        if self.deploy:
            return
        main_conv = self.branches[0][0]
        k = main_conv.kernel_size[0]
        kernel, bias = 0, 0
        for branch in self.branches:
            w, b = self._fuse_branch(branch, k)
            kernel, bias = kernel + w, bias + b
        self.conv = nn.Conv2d(main_conv.in_channels, main_conv.out_channels, main_conv.kernel_size,
                              main_conv.stride, main_conv.padding, dilation=main_conv.dilation,
                              groups=main_conv.groups, bias=True).to(kernel.device)
        self.conv.weight.copy_(kernel)
        self.conv.bias.copy_(bias)
        del self.branches
        self.deploy = True


if __name__ == "__main__":
    # Sanity check: outputs should match before and after fusion (BN in eval mode)
    torch.manual_seed(0)
    block = RepConv(8, 8).eval()
    x = torch.randn(2, 8, 16, 16)
    with torch.no_grad():
        y_train = block(x)
        block.switch_to_deploy()
        y_deploy = block(x)
    print(torch.allclose(y_train, y_deploy, atol=1e-5))
```