Adding SKNet to the DeepLabV3+ network in mmsegmentation
mmsegmentation is a PyTorch-based image segmentation toolbox that provides implementations of many segmentation networks and supports multi-GPU training and inference. Among them, DeepLabV3+ is a widely used semantic segmentation network with high accuracy and good robustness.
SKNet (Selective Kernel Networks) is a convolutional architecture in which each unit adaptively selects among convolution kernels of different sizes. This selective-kernel mechanism increases the network's representational power and has worked well on a number of computer vision tasks.
In mmsegmentation, plugging SKNet into the DeepLabV3+ network can further improve segmentation accuracy, particularly on images with rich detail and texture. Because the selective-kernel modules are relatively lightweight, the gain comes without adding much extra computation.
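To make the mechanism concrete, below is a minimal PyTorch sketch of a two-branch selective-kernel convolution in the spirit of the SKNet paper. The class name, the channel-reduction numbers, and the branch choices (a plain 3x3 convolution and a dilated 3x3 convolution) are illustrative assumptions, not code from mmsegmentation or an official SKNet release.
```python
import torch
import torch.nn as nn


class SKConv(nn.Module):
    """Two-branch selective-kernel convolution (illustrative sketch)."""

    def __init__(self, channels, reduction=16, min_attn_channels=32):
        super().__init__()
        attn_channels = max(channels // reduction, min_attn_channels)
        # Two branches with different receptive fields: a plain 3x3 conv and
        # a dilated 3x3 conv that emulates a 5x5 kernel.
        self.branch_3x3 = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(channels), nn.ReLU(inplace=True))
        self.branch_5x5 = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=2, dilation=2, bias=False),
            nn.BatchNorm2d(channels), nn.ReLU(inplace=True))
        # Fuse-and-select: global pooling, a squeeze layer, then one set of
        # attention logits per branch.
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.squeeze = nn.Sequential(
            nn.Conv2d(channels, attn_channels, 1, bias=False),
            nn.BatchNorm2d(attn_channels), nn.ReLU(inplace=True))
        self.attn = nn.Conv2d(attn_channels, channels * 2, 1)

    def forward(self, x):
        n, c = x.size(0), x.size(1)
        feats = torch.stack([self.branch_3x3(x), self.branch_5x5(x)], dim=1)
        fused = feats.sum(dim=1)                       # element-wise fusion
        logits = self.attn(self.squeeze(self.pool(fused)))
        weights = torch.softmax(logits.view(n, 2, c, 1, 1), dim=1)
        return (feats * weights).sum(dim=1)            # kernel selection


if __name__ == '__main__':
    y = SKConv(64)(torch.randn(2, 64, 32, 32))
    print(y.shape)  # torch.Size([2, 64, 32, 32])
```
A full SKNet backbone stacks such units inside residual blocks; to use it in mmsegmentation it additionally has to be registered so the config system can build it, as outlined in the steps below.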
Related questions
What are the concrete steps for adding SKNet to the DeepLabV3+ network in mmsegmentation?
The rough workflow for adding SKNet to the DeepLabV3+ network in mmsegmentation is as follows:
1. Create a new file sknet.py under mmseg/models/backbones/ to hold the SKNet architecture.
2. In sknet.py, implement SKNet's basic modules (the selective-kernel convolution / SKUnit blocks) and the full SKNet class, and register the class with the BACKBONES registry.
3. Export the new class from mmseg/models/backbones/__init__.py so that the config system can find it. Note that in mmsegmentation DeepLabV3+ is assembled from a config file rather than from a single deeplabv3plus.py network file, so no other source file needs to import the backbone directly.
4. In a DeepLabV3+ config (for example one based on configs/_base_/models/deeplabv3plus_r50-d8.py), point the backbone section at the new SKNet backbone, and adjust the in_channels of the decode head (and auxiliary head) to match SKNet's output channels, as sketched below.
5. Set the training hyperparameters for the new backbone, such as learning rate and weight decay, in the config's optimizer and learning-rate schedule sections.
6. The data pipeline, evaluation, checkpoint saving and inference settings are likewise driven by the config; when only the backbone changes, they normally need no modifications to the mmseg source code.
Note that these are only the broad strokes; the details depend on the mmsegmentation version and on the particular SKNet implementation. It is worth studying the mmseg code structure and the SKNet design carefully while implementing, to make sure the code is correct and reliable.
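As a concrete illustration of step 4, here is a hedged config sketch. It assumes a backbone class registered under the name SKNet (a hypothetical name; use whatever your sknet.py actually registers), uses _delete_=True so the ResNet-specific keys of the base config are not merged into the new backbone, and the head channel numbers must be adapted to the feature maps your SKNet really outputs.
```python
# configs/deeplabv3plus/deeplabv3plus_sknet-d8.py (hypothetical path)
_base_ = '../_base_/models/deeplabv3plus_r50-d8.py'
model = dict(
    backbone=dict(
        _delete_=True,            # do not inherit ResNet-specific keys
        type='SKNet',             # class registered in mmseg/models/backbones/sknet.py
        out_indices=(0, 1, 2, 3)),  # hypothetical argument of the custom backbone
    # in_channels / c1_in_channels must match the SKNet features consumed by
    # the decode head and auxiliary head.
    decode_head=dict(in_channels=2048, c1_in_channels=256),
    auxiliary_head=dict(in_channels=1024))
```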
Adding SKNet to the MobileNetV2 backbone under the mmsegmentation framework
In the `mmsegmentation` framework, SKNet can be added to the `MobileNetV2` backbone by modifying `mmseg/models/backbones/mobilenet_v2.py`. The steps are as follows:
1. First import the modules that the modified backbone needs; add the following at the top of the file:
```python
import torch.nn as nn
from mmcv.cnn import ConvModule, build_conv_layer, build_norm_layer, build_plugin_layer
from ..builder import BACKBONES
from ..utils import InvertedResidual
```
2. Next, insert the SKNet module into the stem of the `MobileNetV2` class by extending the `_make_stem_layer` method, for example:
```python
def _make_stem_layer(self, in_channels, stem_channels):
    layers = []
    # Standard stem: a strided 3x3 conv followed by a 3x3 conv, both with BN + ReLU.
    layers.append(ConvModule(
        in_channels,
        stem_channels,
        3,
        stride=2,
        padding=1,
        norm_cfg=dict(type='BN', momentum=0.1, eps=1e-5),
        act_cfg=dict(type='ReLU', inplace=True)))
    in_channels = stem_channels
    layers.append(ConvModule(
        in_channels,
        in_channels,
        3,
        stride=1,
        padding=1,
        norm_cfg=dict(type='BN', momentum=0.1, eps=1e-5),
        act_cfg=dict(type='ReLU', inplace=True)))
    # Add the SKNet module. 'SKConv' is NOT a built-in mmcv plugin: it must be
    # implemented and registered with mmcv's PLUGIN_LAYERS first (see the
    # sketch after this code block), and its constructor must accept these keys.
    channels = in_channels
    mid_channels = channels // 2
    squeeze_channels = max(1, mid_channels // 8)
    sk_cfg = dict(
        type='SKConv',
        channels=channels,
        squeeze_channels=squeeze_channels,
        kernel_size=3,
        stride=1,
        padding=1,
        dilation=1,
        groups=32,
        sk_mode='two',
        norm_cfg=dict(type='BN', momentum=0.1, eps=1e-5),
        act_cfg=dict(type='ReLU', inplace=True))
    # build_plugin_layer returns (name, layer); keep only the layer.
    layers.append(build_plugin_layer(sk_cfg)[1])
    # Follow the SK block with a plain 3x3 conv + BN, as in the original recipe.
    layers.append(build_conv_layer(
        dict(type='Conv2d'),
        channels,
        channels,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False))
    layers.append(
        build_norm_layer(dict(type='BN', momentum=0.1, eps=1e-5), channels)[1])
    return nn.Sequential(*layers)
```
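Because mmcv does not ship an `SKConv` plugin, `build_plugin_layer` can only construct it if you register one yourself. Below is a minimal, hypothetical registration sketch (assuming mmcv 1.x, where `PLUGIN_LAYERS` is exported from `mmcv.cnn`); the class name and constructor arguments simply mirror the keys used in the plugin cfg above, and the body is a placeholder that should be replaced by a real selective-kernel block such as the one sketched earlier in this article.
```python
import torch.nn as nn
from mmcv.cnn import PLUGIN_LAYERS


@PLUGIN_LAYERS.register_module()
class SKConv(nn.Module):
    """Placeholder plugin so build_plugin_layer(dict(type='SKConv', ...)) works."""

    def __init__(self, channels, squeeze_channels, kernel_size=3, stride=1,
                 padding=1, dilation=1, groups=32, sk_mode='two',
                 norm_cfg=None, act_cfg=None):
        super().__init__()
        # Replace this identity with a real selective-kernel block, e.g. the
        # SKConv module sketched near the top of this article.
        self.block = nn.Identity()

    def forward(self, x):
        return self.block(x)
```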
3. Finally, make sure the modified `MobileNetV2` class is registered in `BACKBONES` (and exported from `mmseg/models/backbones/__init__.py`). If you put the class in a new file while mmseg's original `MobileNetV2` stays registered, give it a different name or pass `force=True` to `register_module` to avoid a registry conflict. A complete sketch of the modified file looks like this:
```python
import torch.nn as nn
from mmcv.cnn import ConvModule, build_conv_layer, build_norm_layer, build_plugin_layer

from ..builder import BACKBONES
from ..utils import InvertedResidual


@BACKBONES.register_module()
class MobileNetV2(nn.Module):
    """MobileNetV2 backbone with an SK block inserted into the stem.

    Simplified sketch: compared with mmseg's own MobileNetV2 it hard-codes the
    stage settings and exposes every stage output.
    """

    def __init__(self,
                 widen_factor=1.0,
                 output_stride=32,
                 norm_cfg=dict(type='BN', momentum=0.1, eps=1e-5),
                 with_cp=False):
        super(MobileNetV2, self).__init__()
        assert output_stride in [8, 16, 32]
        self.output_stride = output_stride
        self.with_cp = with_cp
        self.norm_cfg = norm_cfg
        input_channel = int(32 * widen_factor)
        self.stem = self._make_stem_layer(3, input_channel)
        # Standard MobileNetV2 stages (expand ratio t, out channels, blocks,
        # stride).  For output_stride 8 or 16 the strides of the later stages
        # are replaced by dilations, the usual DeepLab-style trick.
        if output_stride == 8:
            s4, d4, s6, d6 = 1, 2, 1, 4
        elif output_stride == 16:
            s4, d4, s6, d6 = 2, 1, 1, 2
        else:
            s4, d4, s6, d6 = 2, 1, 2, 1
        self.layer1 = self._make_layer(
            input_channel, int(16 * widen_factor), 1, 1, 1)
        self.layer2 = self._make_layer(
            int(16 * widen_factor), int(24 * widen_factor), 2, 2, 6)
        self.layer3 = self._make_layer(
            int(24 * widen_factor), int(32 * widen_factor), 3, 2, 6)
        self.layer4 = self._make_layer(
            int(32 * widen_factor), int(64 * widen_factor), 4, s4, 6, d4)
        self.layer5 = self._make_layer(
            int(64 * widen_factor), int(96 * widen_factor), 3, 1, 6, d4)
        self.layer6 = self._make_layer(
            int(96 * widen_factor), int(160 * widen_factor), 3, s6, 6, d6)
        self.layer7 = self._make_layer(
            int(160 * widen_factor), int(320 * widen_factor), 1, 1, 6, d6)
        self._freeze_stages()

    def _make_stem_layer(self, in_channels, stem_channels):
        layers = []
        # Standard stem: a strided 3x3 conv followed by a 3x3 conv.
        layers.append(ConvModule(
            in_channels,
            stem_channels,
            3,
            stride=2,
            padding=1,
            norm_cfg=dict(type='BN', momentum=0.1, eps=1e-5),
            act_cfg=dict(type='ReLU', inplace=True)))
        in_channels = stem_channels
        layers.append(ConvModule(
            in_channels,
            in_channels,
            3,
            stride=1,
            padding=1,
            norm_cfg=dict(type='BN', momentum=0.1, eps=1e-5),
            act_cfg=dict(type='ReLU', inplace=True)))
        # Add the SKNet module ('SKConv' must be registered in PLUGIN_LAYERS,
        # see the registration sketch above).
        channels = in_channels
        mid_channels = channels // 2
        squeeze_channels = max(1, mid_channels // 8)
        sk_cfg = dict(
            type='SKConv',
            channels=channels,
            squeeze_channels=squeeze_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            dilation=1,
            groups=32,
            sk_mode='two',
            norm_cfg=dict(type='BN', momentum=0.1, eps=1e-5),
            act_cfg=dict(type='ReLU', inplace=True))
        layers.append(build_plugin_layer(sk_cfg)[1])
        # Follow the SK block with a plain 3x3 conv + BN.
        layers.append(build_conv_layer(
            dict(type='Conv2d'),
            channels,
            channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False))
        layers.append(
            build_norm_layer(dict(type='BN', momentum=0.1, eps=1e-5),
                             channels)[1])
        return nn.Sequential(*layers)

    def _make_layer(self,
                    in_channels,
                    out_channels,
                    num_blocks,
                    stride,
                    t,
                    dilation=1):
        layers = []
        # The first block of a stage handles the stride / channel change.
        layers.append(InvertedResidual(
            in_channels,
            out_channels,
            stride,
            expand_ratio=t,
            dilation=dilation,
            norm_cfg=self.norm_cfg,
            with_cp=self.with_cp))
        for _ in range(1, num_blocks):
            layers.append(InvertedResidual(
                out_channels,
                out_channels,
                1,
                expand_ratio=t,
                dilation=dilation,
                norm_cfg=self.norm_cfg,
                with_cp=self.with_cp))
        return nn.Sequential(*layers)

    def _freeze_stages(self):
        # NOTE: tying freezing to output_stride follows the original answer;
        # a production backbone would normally expose a frozen_stages argument.
        if self.output_stride == 8:
            self.stem.eval()
            for param in self.stem.parameters():
                param.requires_grad = False
        elif self.output_stride == 16:
            for m in [self.stem[0], self.stem[1]]:
                m.eval()
                for param in m.parameters():
                    param.requires_grad = False

    def forward(self, x):
        # Return every stage output; which features the heads consume is
        # decided by in_index / in_channels in the config.
        x = self.stem(x)
        outs = []
        for i in range(1, 8):
            x = getattr(self, f'layer{i}')(x)
            outs.append(x)
        return tuple(outs)
```
This adds the SKNet structure to the `MobileNetV2` backbone.
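To actually train DeepLabV3+ on top of the modified backbone, the config must point at it and the head channels must match what the backbone outputs. Below is a hedged sketch modelled on mmseg's MobileNetV2 DeepLabV3+ configs; the base-config path, the channel numbers, and especially the alignment between `in_index`/`c1` handling in the heads and the features returned by the custom `forward()` are assumptions that need to be verified against your code.
```python
_base_ = '../_base_/models/deeplabv3plus_r50-d8.py'
model = dict(
    backbone=dict(
        _delete_=True,            # drop ResNet-specific keys from the base config
        type='MobileNetV2',       # the class registered above
        widen_factor=1.0,
        output_stride=16),
    # 320 / 24 / 96 are the MobileNetV2 channel widths at the stages the heads
    # usually read; make sure in_index and the outputs of forward() line up.
    decode_head=dict(in_channels=320, c1_in_channels=24),
    auxiliary_head=dict(in_channels=96))
```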