GAN: Generating Handwritten Digits with a Generative Adversarial Network
### Generating MNIST Handwritten Digits with a GAN
#### Defining the Generator and Discriminator Models
To build a GAN that can generate handwritten digits, two main components must be designed: a generator and a discriminator [^1]. The generator creates new images that look like real handwritten digits, while the discriminator tries to distinguish these synthetic images from real images drawn from the MNIST dataset.
```python
import torch
from torch import nn, optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader


class Generator(nn.Module):
    """Maps a latent vector z to a 28x28 grayscale image."""

    def __init__(self, input_dim=100, output_dim=784):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.ReLU(),
            nn.Linear(256, output_dim),
            nn.Tanh(),  # outputs in [-1, 1], matching the normalized training data
        )

    def forward(self, z):
        img = self.model(z)
        return img.view(-1, 1, 28, 28)


class Discriminator(nn.Module):
    """Scores an image with the probability that it is real."""

    def __init__(self, input_dim=784):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid(),  # probability in (0, 1)
        )

    def forward(self, img):
        flattened_img = img.view(img.size(0), -1)
        prob_real = self.model(flattened_img)
        return prob_real.squeeze()


device = 'cuda' if torch.cuda.is_available() else 'cpu'
gen = Generator().to(device)
dis = Discriminator().to(device)
```
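Before wiring up the training loop, it can help to confirm the tensor shapes flowing between the two networks. A minimal sanity check, assuming the `gen` and `dis` instances defined above (the batch size of 16 here is arbitrary):
```python
# Quick shape check on the untrained networks
z = torch.randn(16, 100, device=device)   # batch of 16 latent vectors
fake = gen(z)                             # -> torch.Size([16, 1, 28, 28])
scores = dis(fake)                        # -> torch.Size([16]), values in (0, 1)
print(fake.shape, scores.shape)
```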
#### Setting Up the Loss Function and Optimizers
Choosing an appropriate loss function is essential for training. GANs typically use the binary cross-entropy loss to measure the gap between real and fake samples [^2]. In addition, the Adam optimizer is configured to update the parameters of each network.
```python
# Binary cross-entropy measures how well the discriminator separates real from fake
criterion = nn.BCELoss()
# betas=(0.5, 0.999) is the Adam setting commonly used for GAN training
d_optimizer = optim.Adam(dis.parameters(), lr=0.0002, betas=(0.5, 0.999))
g_optimizer = optim.Adam(gen.parameters(), lr=0.0002, betas=(0.5, 0.999))
```
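To see what the BCE loss rewards, consider a toy example (the values here are hypothetical, not from a training run): a discriminator output near 1 on a real-labeled sample yields a small loss, while an output near 0 yields a large one.
```python
# Toy illustration of nn.BCELoss (hypothetical values)
real_target = torch.ones(1)
print(criterion(torch.tensor([0.9]), real_target).item())  # ~0.105: confident and correct
print(criterion(torch.tensor([0.1]), real_target).item())  # ~2.303: confident and wrong
```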
#### Training
In each iteration, the generator is first held fixed and only the discriminator's weights are updated, so that it gets better at spotting forgeries; then the discriminator is frozen and backpropagation is applied to the generator, improving its ability to produce convincing fakes. This alternating process repeats until a stopping condition is met.
```python
batch_size = 64
epochs = 200

# Normalize to [-1, 1] so real images match the generator's Tanh output range
transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5,), (0.5,))])
train_loader = DataLoader(
    datasets.MNIST('./data', train=True, download=True, transform=transform),
    batch_size=batch_size, shuffle=True)

for epoch in range(epochs):
    for i, (real_images, _) in enumerate(train_loader):
        real_images = real_images.to(device)
        # Use the actual batch size: the last batch may be smaller than batch_size
        bs = real_images.size(0)
        valid_labels = torch.ones(bs, device=device)
        fake_labels = torch.zeros(bs, device=device)

        # --- Train the discriminator ---
        dis.zero_grad()
        outputs = dis(real_images)
        d_loss_real = criterion(outputs, valid_labels)
        d_x = outputs.mean().item()            # D(x): score on real images

        noise = torch.randn(bs, 100, device=device)
        gen_imgs = gen(noise)
        # detach() keeps this pass from backpropagating into the generator
        outputs = dis(gen_imgs.detach())
        d_loss_fake = criterion(outputs, fake_labels)
        d_g_z1 = outputs.mean().item()         # D(G(z)) before the G update

        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        d_optimizer.step()

        # --- Train the generator ---
        gen.zero_grad()
        outputs = dis(gen_imgs)                    # no detach: gradients flow into G
        g_loss = criterion(outputs, valid_labels)  # G wants D to say "real"
        d_g_z2 = outputs.mean().item()             # D(G(z)) after the D update
        g_loss.backward()
        g_optimizer.step()

        if i % 100 == 0:
            print(f"[Epoch {epoch}/{epochs}] [Batch {i}/{len(train_loader)}] "
                  f"D loss: {d_loss.item():.4f}, G loss: {g_loss.item():.4f} "
                  f"D(x): {d_x:.4f}, D(G(z)): {d_g_z1:.4f}/{d_g_z2:.4f}")
```
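Once training has finished, the generator can be sampled directly to produce new digit images. A minimal sketch, assuming the trained `gen` from above and using `torchvision.utils.save_image` (the file name and grid size here are arbitrary choices):
```python
from torchvision.utils import save_image

gen.eval()
with torch.no_grad():
    z = torch.randn(64, 100, device=device)
    samples = gen(z)               # values in [-1, 1] because of the Tanh output
    samples = samples * 0.5 + 0.5  # rescale to [0, 1], undoing Normalize((0.5,), (0.5,))
    save_image(samples, 'gan_mnist_samples.png', nrow=8)  # 8x8 grid of digits
```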