`super(ResNet_CD, self).__init__()` 和 `expand = 1` 是什么意思?
时间: 2024-05-22 19:10:37 浏览: 9
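`super(ResNet_CD, self).__init__()` 调用的是 `ResNet_CD` 的父类的构造函数,而不是 `ResNet_CD` 自身的构造函数。确切地说,它按 `self` 所属类的 MRO(方法解析顺序)查找排在 `ResNet_CD` 之后的下一个类,并调用其 `__init__`;单继承时这就是 `ResNet_CD` 的直接父类(在 PyTorch 模型中通常是 `nn.Module`)。如果父类是 `nn.Module`,必须先执行这一行,之后才能给 `self` 赋值子模块,否则会报 “cannot assign module before Module.__init__() call” 错误。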
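`expand = 1` 只是一个赋值语句,把值 1 赋给变量 `expand`;它的确切含义取决于后续代码如何使用它,单看这一行无法确定。在 ResNet 类的实现中,这类变量(torchvision 中对应的类属性名为 `expansion`)通常表示残差块输出通道数相对于该块基础通道数的倍数:`BasicBlock` 为 1,`Bottleneck` 为 4。如果 `expand = 1`,输出通道数等于块的基础通道数,即不做扩展;如果 `expand > 1`,输出通道数按该倍数放大。注意它与卷积核的空洞率(dilation)无关。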
相关问题
更改 `import torch import torchvision.models as models import torch.nn as nn import torch.nn.functional as F class eca_Resnet50(nn.Module): def __init__(self): super().__init__() self.model = models.resnet50(pretrained=True) self.model.avgpool = nn.AdaptiveAvgPool2d((1,1)) self.model.fc = nn.Linear(2048, 1000) self.eca = ECA_Module(2048, 8) def forward(self, x): x = self.model.conv1(x) x = self.model.bn1(x) x = self.model.relu(x) x = self.model.maxpool(x) x = self.model.layer1(x) x = self.model.layer2(x) x = self.model.layer3(x) x = self.model.layer4(x) x = self.eca(x) x = self.model.avgpool(x) x = torch.flatten(x, 1) x = self.model.fc(x) return x class ECA_Module(nn.Module): def __init__(self, channel, k_size=3): super(ECA_Module, self).__init__() self.avg_pool = nn.AdaptiveAvgPool2d(1) self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False) self.sigmoid = nn.Sigmoid() def forward(self, x): b, c, _, _ = x.size() y = self.avg_pool(x) y = self.conv(y.squeeze(-1).transpose(-1,-2)).transpose(-1,-2).unsqueeze(-1) y = self.sigmoid(y) return x * y.expand_as(x) class ImageDenoising(nn.Module): def __init__(self): super().__init__() self.model = eca_Resnet50() self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1) self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1) self.conv3 = nn.Conv2d(64, 3, kernel_size=3, stride=1, padding=1) def forward(self, x): x = self.conv1(x) x = F.relu(x) x = self.conv2(x) x = F.relu(x) x = self.conv3(x) x = F.relu(x) return x`,使最后输出为 [16,1,50,50]。
更改后的代码如下:
```python
import torch
import torchvision.models as models
import torch.nn as nn
import torch.nn.functional as F
class eca_Resnet50(nn.Module):
    """ResNet-50 backbone with an ECA channel-attention gate inserted after ``layer4``.

    The torchvision ResNet-50 is loaded with ``pretrained=True``. Its stages are
    called one by one so that the ECA gate can re-weight the final feature map
    before global average pooling and the linear classifier.

    Args:
        num_classes: Output size of the final linear layer. Defaults to 1000,
            the value that was previously hard-coded.
        k_size: Kernel size of the ECA 1-D convolution. It must be odd. With
            ``padding=(k_size - 1) // 2`` an even kernel shortens the channel axis
            by one (2048 -> 2047), and the gate can then no longer be broadcast onto
            the feature map. The default 7 is what the ECA paper's adaptive rule
            (gamma=2, b=1) gives for 2048 channels. The previous value 8 made
            ``forward`` fail.
    """

    def __init__(self, num_classes=1000, k_size=7):
        super().__init__()
        # NOTE: newer torchvision releases deprecate ``pretrained=`` in favour of
        # ``weights=``; it is kept here so older torchvision versions still work.
        self.model = models.resnet50(pretrained=True)
        self.model.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # This replaces the pretrained classifier head with a freshly initialised
        # layer, so the ImageNet fc weights are discarded.
        self.model.fc = nn.Linear(2048, num_classes)
        self.eca = ECA_Module(2048, k_size)

    def forward(self, x):
        """Return class logits of shape ``(N, num_classes)`` for an image batch ``x``."""
        net = self.model
        # Stem.
        x = net.conv1(x)
        x = net.bn1(x)
        x = net.relu(x)
        x = net.maxpool(x)
        # Residual stages; layer4 yields the 2048-channel map consumed by fc.
        for stage in (net.layer1, net.layer2, net.layer3, net.layer4):
            x = stage(x)
        # Re-weight channels before pooling.
        x = self.eca(x)
        x = net.avgpool(x)
        x = torch.flatten(x, 1)
        return net.fc(x)
class ECA_Module(nn.Module):
    """Efficient Channel Attention (ECA) gate.

    The gate global-average-pools each channel and runs a 1-D convolution across
    the channel axis of the pooled vector. It then rescales the input channels by
    the sigmoid of the result, so the output has the same shape as the input.

    Args:
        channel: Number of input channels. It is not used to build any layer here,
            because the gate works for any channel count; it is kept for API
            compatibility.
        k_size: Kernel size of the 1-D convolution. An even kernel with
            ``padding=(k_size - 1) // 2`` would shorten the channel axis by one
            and break the final broadcast. Even values are therefore rounded up to
            the next odd number, which is the rule the ECA reference implementation
            applies to its adaptive kernel size.
    """

    def __init__(self, channel, k_size=3):
        super().__init__()
        if k_size % 2 == 0:
            k_size += 1
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` with each channel scaled by a learned gate in [0, 1].

        Args:
            x: 4-D tensor ``(N, C, H, W)``.

        Raises:
            ValueError: if ``x`` is not 4-D.
        """
        if x.dim() != 4:
            raise ValueError(f"ECA_Module expects a 4-D (N, C, H, W) tensor, got {x.dim()}-D")
        y = self.avg_pool(x)                  # (N, C, 1, 1)
        y = y.squeeze(-1).transpose(-1, -2)   # (N, 1, C): channels become the 1-D sequence
        y = self.conv(y)                      # odd kernel + same padding keeps length C
        y = y.transpose(-1, -2).unsqueeze(-1) # back to (N, C, 1, 1)
        return x * self.sigmoid(y).expand_as(x)
class ImageDenoising(nn.Module):
    """Three 3x3 convolutions mapping a 3-channel 50x50 image to a 1-channel 50x50 map.

    Every convolution uses ``kernel_size=3, stride=1, padding=1`` and so keeps the
    spatial size of its input. A ``(N, 3, 50, 50)`` batch therefore yields a
    ``(N, 1, 50, 50)`` output.
    """

    def __init__(self):
        super().__init__()
        # NOTE(review): this ResNet-50 backbone is built (with pretrained weights)
        # but never used in forward(). It is kept only so existing state_dict keys
        # still match; consider removing it if that compatibility is not needed.
        self.model = eca_Resnet50()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(64, 1, kernel_size=3, stride=1, padding=1)

    def forward(self, x):
        """Return a non-negative ``(N, 1, 50, 50)`` map for ``x`` of shape ``(N, 3, 50, 50)``.

        Raises:
            ValueError: if the last two dimensions of ``x`` are not ``(50, 50)``.
                The reshape below used to fold larger inputs silently into extra
                batch entries (e.g. 16x3x100x100 -> 64x1x50x50).
        """
        spatial = tuple(x.shape[-2:])
        if spatial != (50, 50):
            raise ValueError(f"ImageDenoising expects 50x50 inputs, got spatial size {spatial}")
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        # The final ReLU clamps the output to be non-negative.
        x = F.relu(self.conv3(x))
        # The convs preserve 50x50, so this only adds a batch dim for unbatched input.
        return x.view(-1, 1, 50, 50)
```
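修改主要有两处:将最后一层卷积 `conv3` 的输出通道数由 3 改为 1,并在 `forward` 末尾添加 `x.view(-1, 1, 50, 50)`。需要注意,三层卷积均为 `kernel_size=3, stride=1, padding=1`,不会改变特征图的空间尺寸,因此只有当输入为 [16,3,50,50] 时输出才是 [16,1,50,50]。此时 `view` 实际上并不改变形状。如果输入不是 50×50,`view` 要么直接报错,要么把多余的像素错误地折叠进批次维度(例如输入 [16,3,100,100] 会得到 [64,1,50,50]),并不能把输出真正“调整”为 50×50;这种情况应先对输入做缩放或裁剪。另外,`self.model = eca_Resnet50()` 在 `forward` 中并未被使用。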
class MyNet(nn.Module): def __init__(self): super(MyNet, self).__init__() self.vgg16 = vgg16(pretrained=True) self.resnet18 = resnet18(pretrained=True) self.vgg16.classifier = nn.Identity() self.resnet18.fc = nn.Identity() self.fc = nn.Linear(25600, 2) def forward(self, x): x1 = self.vgg16(x) x2 = self.resnet18(x) x1 = x1.view(x1.size(0), -1) x2 = x2.view(x2.size(0), -1) x = torch.cat((x1, x2), dim=1) x = self.fc(x) return x 将以上代码加入DANet注意力机制
可以将 DAnet 注意力机制加入到 MyNet 的 forward 函数中,如下所示:
```python
import torch
import torch.nn as nn
from torchvision.models import resnet18, vgg16
class DAnet(nn.Module):
    """Squeeze-and-excitation style channel-attention gate.

    NOTE(review): despite its name, this is not DANet's dual (position + channel)
    attention. It global-average-pools each channel, passes the result through a
    1x1-conv bottleneck (ReLU) and a sigmoid, and rescales the input channels.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by the gate. It must equal ``in_channels``
            (or be 1) so that ``expand_as`` can broadcast the gate onto the input.
        reduction: Bottleneck reduction ratio. The hidden width is
            ``in_channels // reduction``, clamped to at least 1.
    """

    def __init__(self, in_channels, out_channels, reduction=16):
        super(DAnet, self).__init__()
        # Without the clamp, in_channels < reduction gives a 0-channel bottleneck.
        # conv2 would then output only its (uninitialised) bias.
        hidden = max(1, in_channels // reduction)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv1 = nn.Conv2d(in_channels, hidden, kernel_size=1, stride=1, padding=0)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(hidden, out_channels, kernel_size=1, stride=1, padding=0)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` with each channel scaled by a learned gate in [0, 1].

        Args:
            x: 4-D tensor ``(N, C, H, W)``.

        Raises:
            ValueError: if ``x`` is not 4-D.
        """
        if x.dim() != 4:
            raise ValueError(f"DAnet expects a 4-D (N, C, H, W) tensor, got {x.dim()}-D")
        y = self.avg_pool(x)            # (N, C, 1, 1)
        y = self.relu(self.conv1(y))    # squeeze to the bottleneck width
        y = self.sigmoid(self.conv2(y)) # excite back to out_channels, in [0, 1]
        return x * y.expand_as(x)
class MyNet(nn.Module):
    """Two-branch classifier over VGG16 and ResNet18 features, each re-weighted by ``DAnet``.

    Both backbones see the same input. Each branch's last 512-channel feature map
    (VGG16 ``features`` and ResNet18 ``layer4``) goes through its own ``DAnet``
    gate and is then pooled and flattened. The two branches are concatenated and
    classified by one linear layer.

    Args:
        num_classes: Number of output logits. Defaults to 2, the value that was
            previously hard-coded.
    """

    # Flattened per-sample sizes of the two pooled branches. torchvision's VGG16
    # ``avgpool`` pools to 7x7 and ResNet18's ``avgpool`` to 1x1, both over 512
    # channels; together they give the previously hard-coded 25600 = 25088 + 512.
    VGG_FEATURES = 512 * 7 * 7
    RESNET_FEATURES = 512

    def __init__(self, num_classes=2):
        super(MyNet, self).__init__()
        self.vgg16 = vgg16(pretrained=True)
        self.resnet18 = resnet18(pretrained=True)
        # forward() never calls these heads. Replacing them with Identity only
        # drops their unused parameters.
        self.vgg16.classifier = nn.Identity()
        self.resnet18.fc = nn.Identity()
        self.danet1 = DAnet(512, 512)
        self.danet2 = DAnet(512, 512)
        self.fc = nn.Linear(self.VGG_FEATURES + self.RESNET_FEATURES, num_classes)

    def forward(self, x):
        """Return logits of shape ``(N, num_classes)`` for an image batch ``x``."""
        x1 = self._vgg_branch(x)
        x2 = self._resnet_branch(x)
        return self.fc(torch.cat((x1, x2), dim=1))

    def _vgg_branch(self, x):
        """Run VGG16 features -> DAnet gate -> avgpool -> flatten (``VGG_FEATURES`` per sample)."""
        x = self.vgg16.features(x)
        x = self.danet1(x)
        x = self.vgg16.avgpool(x)
        return torch.flatten(x, 1)

    def _resnet_branch(self, x):
        """Run ResNet18 stem and stages -> DAnet gate -> avgpool -> flatten (``RESNET_FEATURES`` per sample)."""
        net = self.resnet18
        x = net.conv1(x)
        x = net.bn1(x)
        x = net.relu(x)
        x = net.maxpool(x)
        for stage in (net.layer1, net.layer2, net.layer3, net.layer4):
            x = stage(x)
        x = self.danet2(x)
        x = net.avgpool(x)
        return torch.flatten(x, 1)
```
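在 forward 函数中,首先分别计算 VGG16 和 ResNet18 的卷积特征。然后将 VGG16 的特征图 x1 和 ResNet18 的特征图 x2(均为 512 通道)分别输入两个 DAnet 模块,进行通道加权。加权后的特征图各自经过平均池化并展平:VGG16 分支为 512×7×7=25088 维,ResNet18 分支为 512 维。两者拼接成 25600 维向量后,再通过全连接层进行分类。需要说明的是,这里的 DAnet 实际上是 SE(Squeeze-and-Excitation)式的通道注意力,流程为:全局平均池化 → 1×1 卷积降维 → ReLU → 1×1 卷积升维 → Sigmoid。它并不是 DANet 论文中由位置注意力模块(PAM)和通道注意力模块(CAM)组成的双注意力结构。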
相关推荐
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)