给出edgeconnect中使用places2数据集训练的InpaintingModel_gen.pth的原版上下文编码器代码
时间: 2023-06-15 20:07:38 浏览: 113
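以下是一个上下文编码器（Context Encoder）风格的卷积编码器代码。需要注意：它并不是 EdgeConnect 中使用 places2 数据集训练的 InpaintingModel_gen.pth 所对应的原版网络。EdgeConnect 的 InpaintGenerator 由编码器（7×7 卷积 + 2 个步长为 2 的下采样卷积）、8 个空洞残差块和解码器（2 个转置卷积 + 7×7 输出卷积）组成，参数名为 encoder.* / middle.* / decoder.*，与下面的 conv1…conv8、norm1…norm8 不一致，因此这段代码无法直接加载该预训练权重: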
```python
import torch.nn as nn
import torch.nn.functional as F
class Context_Encoder(nn.Module):
    """Eight-stage strided convolutional encoder (conv -> InstanceNorm -> ReLU).

    Every stage is a 4x4 convolution with stride 2 and padding 1, which maps a
    spatial size ``s`` to ``s // 2``. Channel widths go
    4 -> 64 -> 128 -> 256 -> 512 -> 512 -> 512 -> 512 -> 512.

    Input:  tensor of shape (N, 4, H, W). The meaning of the 4 input channels
            is decided by the caller (presumably image + mask/edge -- TODO
            confirm). Because of the eight halvings, H and W must each be at
            least 256; otherwise a convolution would produce an empty map and
            PyTorch raises.
    Output: tensor of shape (N, 512, H // 256, W // 256). It is non-negative
            because the last operation is a ReLU.

    The attribute names ``conv1``..``conv8``, ``norm1``..``norm8`` and
    ``relu`` are kept stable so existing ``state_dict`` checkpoints for this
    class still load.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(4, 64, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)
        self.conv3 = nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)
        self.conv4 = nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1)
        self.conv5 = nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1)
        self.conv6 = nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1)
        self.conv7 = nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1)
        self.conv8 = nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1)
        # affine=True gives each norm a learnable per-channel scale and bias.
        self.norm1 = nn.InstanceNorm2d(64, affine=True)
        self.norm2 = nn.InstanceNorm2d(128, affine=True)
        self.norm3 = nn.InstanceNorm2d(256, affine=True)
        self.norm4 = nn.InstanceNorm2d(512, affine=True)
        self.norm5 = nn.InstanceNorm2d(512, affine=True)
        self.norm6 = nn.InstanceNorm2d(512, affine=True)
        self.norm7 = nn.InstanceNorm2d(512, affine=True)
        self.norm8 = nn.InstanceNorm2d(512, affine=True)
        # ReLU has no state, so one instance is shared by all stages.
        self.relu = nn.ReLU()

    def forward(self, x):
        """Encode ``x`` through the eight conv -> norm -> ReLU stages.

        Args:
            x: input tensor of shape (N, 4, H, W) with H, W >= 256.

        Returns:
            Tensor of shape (N, 512, H // 256, W // 256).
        """
        stages = (
            (self.conv1, self.norm1),
            (self.conv2, self.norm2),
            (self.conv3, self.norm3),
            (self.conv4, self.norm4),
            (self.conv5, self.norm5),
            (self.conv6, self.norm6),
            (self.conv7, self.norm7),
            (self.conv8, self.norm8),
        )
        for conv, norm in stages:
            x = self._norm_relu(norm, conv(x))
        return x

    def _norm_relu(self, norm, x):
        """Apply ``norm`` and then ReLU, skipping ``norm`` on 1x1 feature maps.

        Instance normalization over a single spatial element is degenerate:
        the mean equals the value and the variance is zero, so every channel
        would collapse to its affine bias. PyTorch also raises ``ValueError``
        for it. This happens at ``norm8`` for a 256x256 input. Such maps are
        passed through un-normalized, and any larger map is normalized exactly
        as before.
        """
        if x.shape[-2] * x.shape[-1] > 1:
            x = norm(x)
        return self.relu(x)
```
阅读全文