包含 ResRep 的 pruning_filters 源码(PyTorch 版本)
时间: 2024-04-30 14:20:27 浏览: 151
以下是 PyTorch 版本的 pruning_filters 源码(包含 ResRep):
```python
import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
class PruningFilter(nn.Module):
    """Wrap an ``nn.Conv2d`` and zero out whole output filters with a mask.

    ``filter_mask`` holds one entry per output channel (dim 0 of the wrapped
    weight). It starts as all ones and is refreshed by ``compute_mask`` after a
    ``torch.nn.utils.prune``-style function has been applied to the module.

    Args:
        module: the convolution to wrap. ``forward`` calls
            ``nn.functional.conv2d`` with the module's stride, padding,
            dilation and groups, so only zero padding is reproduced
            (a non-default ``padding_mode`` is ignored).
        name: name of the weight attribute on ``module`` to mask and prune.
    """

    def __init__(self, module, name):
        super(PruningFilter, self).__init__()
        self.module = module
        self.name = name
        weight = getattr(module, name)
        # The mask is non-trainable state, so it is a buffer: it follows
        # .to()/state_dict() but is never handed to an optimizer.
        self.register_buffer(
            "filter_mask",
            torch.ones(weight.shape[0], dtype=weight.dtype, device=weight.device),
        )

    def _current_weight(self):
        """Return the wrapped weight, re-applying any torch pruning mask.

        ``torch.nn.utils.prune`` stores the dense weight as ``<name>_orig`` and
        the mask as ``<name>_mask``, and normally refreshes ``<name>`` in a
        forward pre-hook. ``forward`` below calls ``conv2d`` directly, so that
        hook never fires. The product is therefore rebuilt here, so the weight
        tracks optimizer updates and every forward gets a fresh autograd graph.
        """
        orig = getattr(self.module, self.name + "_orig", None)
        if orig is None:
            return getattr(self.module, self.name)
        return orig * getattr(self.module, self.name + "_mask")

    def forward(self, x):
        """Run the wrapped convolution with masked-out filters removed."""
        mask = self.filter_mask
        weight = self._current_weight() * mask.view(-1, 1, 1, 1)
        bias = self.module.bias
        if bias is not None:
            # A pruned filter must contribute nothing, including its bias term.
            bias = bias * mask
        return nn.functional.conv2d(
            x,
            weight,
            bias,
            self.module.stride,
            self.module.padding,
            self.module.dilation,
            self.module.groups,
        )

    def compute_mask(self, prune_method, **prune_kwargs):
        """Prune the wrapped module, then mark filters that became all-zero.

        Args:
            prune_method: a callable invoked as
                ``prune_method(module, name=..., **prune_kwargs)``, such as
                ``prune.ln_structured``.
            **prune_kwargs: forwarded to ``prune_method`` (e.g. ``amount``).
        """
        prune_method(self.module, name=self.name, **prune_kwargs)
        with torch.no_grad():
            weight = self._current_weight()
            # A filter is kept while any of its weights is still nonzero.
            keep = weight.reshape(weight.shape[0], -1).ne(0).any(dim=1)
            self.filter_mask.copy_(keep)
class PruningResRep(nn.Module):
    """VGG-style CNN whose ``Conv2d`` layers run through ``PruningFilter`` wrappers.

    Every ``nn.Conv2d`` in ``self.features`` is wrapped in a ``PruningFilter``
    registered as ``filter1``, ``filter2``, ... in order of appearance (12 in
    total). ``forward`` runs each feature layer exactly once, substituting the
    wrapper for its convolution.

    NOTE(review): despite the name, this does not implement ResRep's compactor
    layers or gradient resetting; it only applies per-filter masks.
    NOTE: each conv is reachable both as ``features.<i>`` and as
    ``filter<k>.module``, so its tensors appear twice in ``state_dict()``.

    Args:
        num_classes: output size of the final linear layer.
    """

    def __init__(self, num_classes=1000):
        super(PruningResRep, self).__init__()
        # Define the backbone (expects 3-channel input)
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
        )
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)
        # Map: position in self.features -> attribute name of its PruningFilter.
        # Wrappers are derived from the layer type rather than hard-coded
        # indices, so they always target convolutions (pooling, BatchNorm and
        # ReLU layers have no weight to prune).
        self._filter_names = {}
        for index, layer in enumerate(self.features):
            if isinstance(layer, nn.Conv2d):
                filter_name = f"filter{len(self._filter_names) + 1}"
                self.add_module(filter_name, PruningFilter(layer, "weight"))
                self._filter_names[index] = filter_name

    def forward(self, x):
        """Run the backbone once, routing every conv through its filter wrapper."""
        for index, layer in enumerate(self.features):
            filter_name = self._filter_names.get(index)
            if filter_name is not None:
                layer = getattr(self, filter_name)
            x = layer(x)
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        return self.fc(x)

    def compute_mask(self, prune_method, **prune_kwargs):
        """Apply ``prune_method`` to every wrapped conv and refresh its filter mask.

        Args:
            prune_method: a callable invoked as
                ``prune_method(conv, name="weight", **prune_kwargs)``.
            **prune_kwargs: forwarded unchanged to each call.
        """
        for filter_name in self._filter_names.values():
            getattr(self, filter_name).compute_mask(prune_method, **prune_kwargs)
```
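其中,`PruningFilter` 是用于对卷积层按滤波器(filter)进行剪枝的模块,`PruningResRep` 是 ResRep 模型的剪枝版本。在 `PruningResRep` 中,我们将每个卷积层都封装成了一个 `PruningFilter` 模块,并在 `compute_mask` 方法中调用每个 `PruningFilter` 模块的 `compute_mask` 方法来计算每个卷积层的剪枝掩码。需要注意的是,这里的主干网络只是普通的 VGG 风格卷积堆叠,并未实现 ResRep 论文中的 compactor 层与梯度重置(gradient resetting)机制,因此只能看作基于掩码的滤波器剪枝的简化示例。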
阅读全文