CycleGAN generator and discriminator
Date: 2024-12-31 13:19:59
### How the CycleGAN generator and discriminator work
#### How the generator works
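The generators in CycleGAN map data from one domain to another without requiring a paired dataset. Specifically, there are two generators, \( G \) and \( F \):
- **\( G: X \to Y \)** translates images from the source domain \( X \) into the target domain \( Y \)[^2].
- **\( F: Y \to X \)** does the reverse, translating images from the target domain \( Y \) back into the source domain \( X \).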
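To keep the two mappings consistent with each other, CycleGAN adds a cycle consistency loss that constrains the whole process. If \( G \) is applied first and \( F \) second, the result should be as close as possible to the original input image, i.e. \( F(G(x)) \approx x \); symmetrically, \( G(F(y)) \approx y \). This helps preserve image content and prevents information loss[^3].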
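The cycle consistency idea can be illustrated with a small framework-free sketch. This is a toy under stated assumptions, not the training code of any particular implementation: `G`, `F_exact` and `F_lossy` are hypothetical stand-ins for the generators that act on flat lists of pixel values, the distance is a mean L1 difference, and the weight `lam=10.0` is an assumed value that does not come from the text above.

```python
# Toy sketch of a cycle consistency loss (assumptions: L1 distance, lam=10.0).
# Real CycleGAN generators are neural networks acting on image tensors;
# the functions below are hypothetical stand-ins used only for illustration.

def l1_mean(a, b):
    """Mean absolute difference between two equal-length sequences."""
    return sum(abs(u - v) for u, v in zip(a, b)) / len(a)


def cycle_consistency_loss(G, F, x_batch, y_batch, lam=10.0):
    """lam * (mean|F(G(x)) - x| + mean|G(F(y)) - y|)."""
    forward_cycle = l1_mean(F(G(x_batch)), x_batch)   # X -> Y -> X
    backward_cycle = l1_mean(G(F(y_batch)), y_batch)  # Y -> X -> Y
    return lam * (forward_cycle + backward_cycle)


def G(xs):
    """Hypothetical generator X -> Y."""
    return [2.0 * v + 1.0 for v in xs]


def F_exact(ys):
    """Hypothetical generator Y -> X that inverts G exactly."""
    return [(v - 1.0) / 2.0 for v in ys]


def F_lossy(ys):
    """Hypothetical generator Y -> X that inverts G only approximately."""
    return [(v - 1.0) / 2.0 + 0.1 for v in ys]


x = [0.0, 0.25, 0.5]   # toy samples from domain X
y = [1.0, 1.5, 2.0]    # toy samples from domain Y

loss_exact = cycle_consistency_loss(G, F_exact, x, y)
loss_lossy = cycle_consistency_loss(G, F_lossy, x, y)
print(loss_exact, loss_lossy)
```

When the two mappings invert each other exactly the loss is zero; any drift in either direction of the cycle raises it, which is the pressure that keeps the translated image tied to its source.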
```python
import torch.nn as nn


class Generator(nn.Module):
    """ResNet-style generator: stem, downsampling, residual blocks, upsampling."""

    def __init__(self, input_nc, output_nc, ngf=64, n_blocks=9):
        super().__init__()
        # Number of stride-2 stages. The original snippet left this name
        # undefined; 2 is assumed here.
        n_downsampling = 2

        # Stem: 7x7 convolution with reflection padding (spatial size unchanged).
        model = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=False),
            nn.InstanceNorm2d(ngf),
            nn.ReLU(True),
        ]

        # Downsampling: each stage halves the spatial size and doubles the channels.
        for i in range(n_downsampling):
            mult = 2 ** i
            model += [
                nn.Conv2d(ngf * mult, ngf * mult * 2,
                          kernel_size=3, stride=2, padding=1, bias=False),
                nn.InstanceNorm2d(ngf * mult * 2),
                nn.ReLU(True),
            ]

        # Residual blocks at the lowest resolution.
        mult = 2 ** n_downsampling
        for _ in range(n_blocks):
            model += [ResnetBlock(ngf * mult, norm_layer=nn.InstanceNorm2d)]

        # Upsampling: each stage doubles the spatial size and halves the channels.
        for i in range(n_downsampling):
            mult = 2 ** (n_downsampling - i)
            model += [
                nn.ConvTranspose2d(ngf * mult, ngf * mult // 2,
                                   kernel_size=3, stride=2,
                                   padding=1, output_padding=1,
                                   bias=False),
                nn.InstanceNorm2d(ngf * mult // 2),
                nn.ReLU(True),
            ]

        # Output layer: 7x7 convolution followed by tanh, giving values in [-1, 1].
        model += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
            nn.Tanh(),
        ]
        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)


class ResnetBlock(nn.Module):
    """Residual block: two 3x3 convolutions plus a skip connection."""

    def __init__(self, dim, use_dropout=False, norm_layer=nn.InstanceNorm2d):
        super().__init__()
        # norm_layer is a class, not an instance, so the isinstance() checks in
        # the original snippet never matched and the padding stayed 0, which
        # shrank the feature map and broke the skip connection. Reflection
        # padding of 1 with an unpadded 3x3 convolution keeps the size fixed.
        conv_block = [
            nn.ReflectionPad2d(1),
            nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
            norm_layer(dim),
            nn.ReLU(True),
        ]
        if use_dropout:
            conv_block += [nn.Dropout(0.5)]
        conv_block += [
            nn.ReflectionPad2d(1),
            nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
            norm_layer(dim),
        ]
        self.conv_block = nn.Sequential(*conv_block)

    def forward(self, x):
        # Skip connection: add the input to the block output.
        return x + self.conv_block(x)
```
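To see why the generator above returns an image of the same size as its input, the layer arithmetic can be traced by hand. The sketch below uses only the standard convolution and transposed-convolution output-size formulas. It assumes two downsampling stages (`n_downsampling=2`) and a 256×256 RGB input, neither of which is stated in the text; the helper names are illustrative.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output length of a convolution along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def conv_transpose_out(size, kernel, stride=1, padding=0, output_padding=0):
    """Output length of a transposed convolution (dilation 1)."""
    return (size - 1) * stride - 2 * padding + (kernel - 1) + output_padding + 1


def generator_shapes(size=256, input_nc=3, output_nc=3, ngf=64, n_downsampling=2):
    """Return (stage, channels, spatial size) for each stage of the generator."""
    shapes = [("input", input_nc, size)]

    # ReflectionPad2d(3) adds 3 pixels per side; the unpadded 7x7 conv removes them.
    size = conv_out(size + 2 * 3, kernel=7)
    ch = ngf
    shapes.append(("stem", ch, size))

    # Each stride-2 convolution halves the size and doubles the channels.
    for i in range(n_downsampling):
        size = conv_out(size, kernel=3, stride=2, padding=1)
        ch *= 2
        shapes.append((f"down{i + 1}", ch, size))

    # Residual blocks change neither the channels nor the size.
    shapes.append(("resnet blocks", ch, size))

    # Each transposed convolution doubles the size and halves the channels.
    for i in range(n_downsampling):
        size = conv_transpose_out(size, kernel=3, stride=2,
                                  padding=1, output_padding=1)
        ch //= 2
        shapes.append((f"up{i + 1}", ch, size))

    size = conv_out(size + 2 * 3, kernel=7)
    shapes.append(("output", output_nc, size))
    return shapes


shapes = generator_shapes()
for stage, ch, size in shapes:
    print(f"{stage:14s} {ch:4d} x {size} x {size}")
```

Under these assumptions the residual blocks operate on 256 channels at 64×64, and `output_padding=1` in the transposed convolutions is what makes each upsampling stage exactly double the size, so the output returns to 256×256.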
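#### How the discriminator works
The discriminator's task is to distinguish real samples from fake samples produced by the generator. CycleGAN builds its discriminator on the PatchGAN design. Rather than predicting whether the whole image is real, PatchGAN outputs a matrix of local realness scores, each covering a small window (patch) of the input. Because the network is fully convolutional, this is cheaper than classifying the full image, and it pushes the generator to get local detail right.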
```python
import torch.nn as nn


class Discriminator(nn.Module):
    """PatchGAN discriminator: outputs one realness score per image patch."""

    def __init__(self, input_nc, ndf=64, n_layers=3):
        super().__init__()
        kw = 4    # kernel size of every convolution
        padw = 1  # padding of every convolution

        # First layer: stride-2 convolution without normalization.
        sequence = [
            nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
            nn.LeakyReLU(0.2, True),
        ]

        # Intermediate stride-2 layers; the channel multiplier is capped at 8.
        nf_mult = 1
        nf_mult_prev = 1
        for n in range(1, n_layers):
            nf_mult_prev = nf_mult
            nf_mult = min(2 ** n, 8)
            sequence += [
                nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                          kernel_size=kw, stride=2, padding=padw, bias=False),
                nn.InstanceNorm2d(ndf * nf_mult),
                nn.LeakyReLU(0.2, True),
            ]

        # One more layer with stride 1, placed after the loop.
        nf_mult_prev = nf_mult
        nf_mult = min(2 ** n_layers, 8)
        sequence += [
            nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult,
                      kernel_size=kw, stride=1, padding=padw, bias=False),
            nn.InstanceNorm2d(ndf * nf_mult),
            nn.LeakyReLU(0.2, True),
        ]

        # Output layer: a single-channel map of per-patch scores (raw logits).
        sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]
        self.model = nn.Sequential(*sequence)

    def forward(self, x):
        return self.model(x)
```
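The "matrix of local scores" can be made concrete by working through the layer arithmetic of the discriminator above. The sketch below is a back-of-the-envelope calculation in plain Python, not part of the model. It assumes the default `n_layers=3` and a 256×256 input (the input size is not given in the text), and the helper names are illustrative.

```python
def conv_out(size, kernel, stride, padding):
    """Output length of a convolution along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1


def patchgan_layers(n_layers=3, kernel=4, padding=1):
    """(kernel, stride, padding) for each convolution in the discriminator."""
    layers = [(kernel, 2, padding)]                      # first convolution
    layers += [(kernel, 2, padding)] * (n_layers - 1)    # convolutions in the loop
    layers += [(kernel, 1, padding)] * 2                 # stride-1 conv + output conv
    return layers


def output_size(input_size, layers):
    """Side length of the score map for a square input."""
    size = input_size
    for kernel, stride, padding in layers:
        size = conv_out(size, kernel, stride, padding)
    return size


def receptive_field(layers):
    """Side length of the input patch that influences one output score."""
    rf = 1
    for kernel, stride, _ in reversed(layers):
        rf = (rf - 1) * stride + kernel
    return rf


layers = patchgan_layers()
print(output_size(256, layers))   # 30
print(receptive_field(layers))    # 70
```

Under these assumptions a 256×256 image yields a 30×30 score map, and each score depends on a 70×70 patch of the input, so neighbouring patches overlap heavily.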