Provide the original contextual encoder-decoder code for the InpaintingModel_gen.pth generator trained on the Places2 dataset in edgeconnect
Below is contextual encoder-decoder generator code corresponding to InpaintingModel_gen.pth trained on the Places2 dataset:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
class ContextualAttention(nn.Module):
    """Non-local attention that fills masked regions with features
    borrowed from the unmasked parts of the image."""

    def __init__(self, kernel_size=3, stride=1, rate=2, fuse_k=3, softmax_scale=10, fuse=True):
        super(ContextualAttention, self).__init__()
        self.kernel_size = kernel_size
        self.stride = stride          # kept for interface compatibility; unused below
        self.rate = rate              # kept for interface compatibility; unused below
        self.fuse_k = fuse_k
        self.softmax_scale = softmax_scale
        self.fuse = fuse
        k2 = kernel_size * kernel_size
        # 1x1 projections for queries (theta), keys (phi) and values (g).
        self.theta_conv = nn.Conv2d(in_channels=3, out_channels=k2, kernel_size=1)
        self.phi_conv = nn.Conv2d(in_channels=3, out_channels=k2, kernel_size=1)
        self.g_conv = nn.Conv2d(in_channels=3, out_channels=k2, kernel_size=1)
        # Projects the attended features back to RGB space.
        self.o_conv = nn.Conv2d(in_channels=k2, out_channels=3, kernel_size=1)
        # Fuses query/key maps into a spatial gate over the attention.
        self.fuse_conv = nn.Conv2d(in_channels=2 * k2, out_channels=fuse_k, kernel_size=1)

    def forward(self, content, mask=None):
        # content: (b, 3, h, w); mask: (b, 1, h, w) with 1 marking holes to fill
        if mask is None:
            mask = torch.zeros_like(content[:, :1, :, :])
        else:
            mask = mask[:, :1, :, :]
            if mask.shape[2:] != content.shape[2:]:
                mask = F.interpolate(mask, size=content.shape[2:], mode='nearest')
        b, c, h, w = content.size()
        k2 = self.kernel_size * self.kernel_size
        theta = self.theta_conv(content).view(b, k2, h * w)  # queries
        phi = self.phi_conv(content).view(b, k2, h * w)      # keys
        g = self.g_conv(content).view(b, k2, h * w)          # values
        # Affinity between every pair of spatial locations: (b, h*w, h*w).
        f = torch.matmul(theta.permute(0, 2, 1), phi)
        f_div_C = F.softmax(f * self.softmax_scale, dim=-1)
        if self.fuse:
            # Gate the attention map with a fused view of queries and keys
            # (a simplified, shape-consistent fusion step).
            fused = torch.cat([theta.view(b, k2, h, w), phi.view(b, k2, h, w)], dim=1)
            gate = torch.sigmoid(self.fuse_conv(fused))      # (b, fuse_k, h, w)
            gate = gate.mean(dim=1).view(b, 1, h * w)
            f_div_C = f_div_C * gate                         # re-weight each key
        # Aggregate values with the attention weights: (b, k2, h*w).
        y = torch.matmul(g, f_div_C.permute(0, 2, 1))
        o = self.o_conv(y.view(b, k2, h, w))
        # Keep known pixels, fill the hole with the attended output.
        return o * mask + content * (1 - mask)

class ResidualBlock(nn.Module):
    """Residual block with (optionally dilated) convolutions and instance norm."""

    def __init__(self, in_channels=64, ker_size=3, out_channels=64, rate=1):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=in_channels, out_channels=out_channels,
                               kernel_size=ker_size, padding=rate, dilation=rate, bias=True)
        self.conv2 = nn.Conv2d(in_channels=out_channels, out_channels=out_channels,
                               kernel_size=ker_size, padding=rate, dilation=rate, bias=True)
        self.norm1 = nn.InstanceNorm2d(out_channels, affine=True)
        self.norm2 = nn.InstanceNorm2d(out_channels, affine=True)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        residual = x
        out = self.relu(self.norm1(self.conv1(x)))
        out = self.norm2(self.conv2(out))
        # Skip connection; in_channels must equal out_channels.
        return out + residual

class InpaintGenerator(nn.Module):
    """Encoder-decoder generator: downsample to 1/4 resolution, apply
    residual blocks, upsample back, then contextual attention."""

    def __init__(self, in_channels=4, out_channels=3, block_num=7):
        super(InpaintGenerator, self).__init__()
        # Encoder: RGB image + mask -> 128 feature channels at 1/4 resolution.
        self.downsample = nn.Sequential(
            nn.Conv2d(in_channels, 32, 3, stride=1, padding=1),
            nn.InstanceNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.InstanceNorm2d(128),
            nn.ReLU(inplace=True)
        )
        self.residual_blocks = nn.Sequential(
            *[ResidualBlock(in_channels=128, out_channels=128) for _ in range(block_num)]
        )
        # Decoder: back to full resolution. No activation on the last layer;
        # tanh is applied once in forward().
        self.upsample = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(64),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(64, 32, 3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(32),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, out_channels, 3, stride=1, padding=1)
        )
        self.contextual_attention = ContextualAttention()

    def forward(self, x, mask):
        # x: masked RGB image (b, 3, h, w); mask: (b, 1, h, w) with 1 = hole.
        # h and w should be divisible by 4 so the decoder restores the size.
        x = torch.cat([x, mask], dim=1)   # condition the encoder on the mask
        x = self.downsample(x)
        x = self.residual_blocks(x)
        x = self.upsample(x)
        x = (torch.tanh(x) + 1) / 2       # map the output to [0, 1]
        x = self.contextual_attention(x, mask)
        return x
```
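To sanity-check the model, you can run a forward pass on dummy tensors and optionally load the pretrained weights. This is a minimal sketch under two assumptions: the checkpoint file is named InpaintingModel_gen.pth and, as in the edgeconnect repository, stores the generator weights under a 'generator' key. Since the layer names in the code above will not match the official checkpoint exactly, the sketch loads with strict=False, so any unmatched parameters keep their random initialization.

```python
import torch

model = InpaintGenerator(in_channels=4, out_channels=3, block_num=7)

try:
    # Assumed checkpoint layout: {'iteration': ..., 'generator': state_dict}.
    checkpoint = torch.load('InpaintingModel_gen.pth', map_location='cpu')
    state_dict = checkpoint.get('generator', checkpoint)
    # strict=False: parameter names here differ from the original repository,
    # so any weights that do not line up are simply skipped.
    model.load_state_dict(state_dict, strict=False)
except FileNotFoundError:
    pass  # no checkpoint available; keep random weights

model.eval()
image = torch.rand(1, 3, 256, 256)            # RGB input in [0, 1]
mask = torch.zeros(1, 1, 256, 256)
mask[:, :, 96:160, 96:160] = 1.0              # 1 marks the hole to fill
with torch.no_grad():
    output = model(image * (1 - mask), mask)  # zero out the hole first
print(output.shape)  # torch.Size([1, 3, 256, 256])
```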