pytorch对抗网络日月光华
时间: 2024-12-31 21:41:49 浏览: 7
### 使用PyTorch实现对抗生成网络 (GAN)
#### 构建生成器
生成器的任务是从随机噪声向量生成逼真的图像。为了构建一个简单的生成器,可以采用全连接层或多层感知机结构。
```python
import torch.nn as nn
class Generator(nn.Module):
def __init__(self, input_size=100, output_size=784, hidden_dim=128):
super(Generator, self).__init__()
# 定义线性层和激活函数
self.fc1 = nn.Linear(input_size, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, hidden_dim * 2)
self.fc3 = nn.Linear(hidden_dim * 2, output_size)
def forward(self, x):
x = nn.functional.relu(self.fc1(x))
x = nn.functional.relu(self.fc2(x))
x = torch.tanh(self.fc3(x)) # 输出范围[-1, 1]
return x
```
此代码定义了一个三层的生成器模型[^1]。
#### 设计判别器
判别器用于区分真实样本与由生成器产生的假样本之间的差异。通常情况下,判别器也是一个神经网络,它接收输入图片并预测其真实性概率值。
```python
class Discriminator(nn.Module):
def __init__(self, input_size=784, hidden_dim=128):
super(Discriminator, self).__init__()
self.fc1 = nn.Linear(input_size, hidden_dim*2)
self.fc2 = nn.Linear(hidden_dim*2, hidden_dim)
self.fc3 = nn.Linear(hidden_dim, 1)
def forward(self, x):
x = nn.functional.leaky_relu(self.fc1(x), negative_slope=0.2)
x = nn.functional.leaky_relu(self.fc2(x), negative_slope=0.2)
x = torch.sigmoid(self.fc3(x)) # 将输出压缩到[0, 1]之间表示真假的概率
return x
```
这段代码实现了具有三个隐藏层的二分类器作为判别器。
#### 训练过程概述
训练过程中交替更新两个网络参数:
- **优化生成器**: 当前目标是最小化`log(1-D(G(z)))`,即让D尽可能认为G生成的数据是真的。
- **优化判别器**: 同时最大化对于实际数据的真实度评估以及最小化对伪造数据的真实性评分。
具体来说,在每次迭代中先固定住生成器权重来调整判别器;接着冻结判别器而仅改变生成器权值以提高欺骗能力。
#### 损失函数设置
损失函数的选择至关重要,这里采用了经典的交叉熵损失:
```python
criterion = nn.BCELoss()
real_label = 1.
fake_label = 0.
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, beta2))
optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, beta2))
for epoch in range(num_epochs):
for i, data in enumerate(dataloader):
############################
# 更新 D 网络: maximize log(D(x)) + log(1 - D(G(z)))
###########################
optimizer_D.zero_grad()
real_images = data[0].to(device).view(-1, image_size) # 调整形状适应全连接层
label_real = torch.full((batch_size,), real_label, device=device)
output_real = discriminator(real_images.float()).view(-1)
errD_real = criterion(output_real, label_real)
errD_real.backward() # 只计算梯度不更新参数
noise_vector = torch.randn(batch_size, nz, device=device)
fake_images = generator(noise_vector)
label_fake = torch.full((batch_size,), fake_label, device=device)
output_fake = discriminator(fake_images.detach()).view(-1) # detach 防止反向传播至 G
errD_fake = criterion(output_fake, label_fake)
errD_fake.backward()
optimizer_D.step()
############################
# 更新 G 网络: minimize log(1 - D(G(z))), equivalent to maximizing log(D(G(z)))
###########################
optimizer_G.zero_grad()
label_g = torch.full((batch_size,), real_label, device=device) # 这里希望生成的结果被认为是真实的
output_g = discriminator(fake_images).view(-1)
errG = criterion(output_g, label_g)
errG.backward()
optimizer_G.step()
```
上述代码展示了完整的训练循环逻辑,其中包含了针对生成器和判别器各自的优化步骤。
阅读全文