GAN 代码(PyTorch)
时间: 2024-05-11 07:13:48 浏览: 79
以下是一个使用 PyTorch 框架实现的简单 GAN 代码示例:
```
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
# Generator model
class Generator(nn.Module):
    """Fully connected generator mapping noise vectors to flattened images.

    Architecture: Linear(latent_dim, 128) -> ReLU -> Linear(128, 256) -> ReLU
    -> Linear(256, image_dim) -> Tanh. The final Tanh bounds every output
    value to the range [-1, 1].

    Args:
        latent_dim: Size of each input noise vector. Defaults to 100.
        image_dim: Size of each flattened output vector. Defaults to 784
            (28 * 28).
    """

    def __init__(self, latent_dim: int = 100, image_dim: int = 784) -> None:
        super(Generator, self).__init__()
        # Attribute names are kept as fc1/fc2/fc3 so state_dict keys are stable.
        self.fc1 = nn.Linear(latent_dim, 128)
        self.fc2 = nn.Linear(128, 256)
        self.fc3 = nn.Linear(256, image_dim)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def forward(self, z):
        """Return generated samples for the noise batch ``z``.

        Args:
            z: Tensor whose last dimension equals ``latent_dim``.

        Returns:
            Tensor with last dimension ``image_dim`` and values in [-1, 1].
        """
        z = self.relu(self.fc1(z))
        z = self.relu(self.fc2(z))
        z = self.tanh(self.fc3(z))
        return z
# Discriminator model
class Discriminator(nn.Module):
    """Fully connected discriminator scoring flattened images as real or fake.

    Architecture: Linear(image_dim, 256) -> ReLU -> Linear(256, 128) -> ReLU
    -> Linear(128, 1) -> Sigmoid. The final Sigmoid makes each output a value
    in [0, 1], which is the input range ``nn.BCELoss`` expects.

    Args:
        image_dim: Size of each flattened input vector. Defaults to 784
            (28 * 28).
    """

    def __init__(self, image_dim: int = 784) -> None:
        super(Discriminator, self).__init__()
        # Attribute names are kept as fc1/fc2/fc3 so state_dict keys are stable.
        self.fc1 = nn.Linear(image_dim, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return one score per input sample.

        Args:
            x: Tensor whose last dimension equals ``image_dim``.

        Returns:
            Tensor with last dimension 1 and values in [0, 1].
        """
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.sigmoid(self.fc3(x))
        return x
# Initialize the models
import os  # local to this script section: needed to create the output folder

generator = Generator()
discriminator = Discriminator()

# Loss function and optimizers
criterion = nn.BCELoss()
optimizer_g = optim.Adam(generator.parameters(), lr=0.0002)
optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0002)

# Train the GAN
# NOTE(review): `train_data` is not defined anywhere in this snippet and must
# be provided before this point. It is sliced along dim 0 and fed straight
# into the discriminator, so it is presumably a float tensor of shape (N, 784).
# Scaling it to [-1, 1] would match the generator's Tanh output -- TODO confirm
# against the actual data-loading code.
batch_size = 128
num_epochs = 100
latent_dim = 100  # must match the generator's input size

# Only full batches are used; the trailing partial batch is dropped so the
# fixed-size label tensors below always match the discriminator output.
num_batches = len(train_data) // batch_size
if num_batches == 0:
    # Without this check the per-epoch print below would fail with a
    # NameError because no loss would ever be computed.
    raise ValueError(
        'train_data has %d samples, fewer than batch_size=%d'
        % (len(train_data), batch_size)
    )

# plt.savefig does not create missing directories.
os.makedirs('samples', exist_ok=True)

for epoch in range(num_epochs):
    for i in range(num_batches):
        # ---- Train the discriminator ----
        discriminator.zero_grad()
        real_images = train_data[i * batch_size: (i + 1) * batch_size]
        real_labels = torch.ones(batch_size, 1)
        # Detach so the discriminator loss does not backpropagate through
        # (and accumulate gradients in) the generator.
        fake_images = generator(torch.randn(batch_size, latent_dim)).detach()
        fake_labels = torch.zeros(batch_size, 1)
        real_outputs = discriminator(real_images)
        fake_outputs = discriminator(fake_images)
        d_loss_real = criterion(real_outputs, real_labels)
        d_loss_fake = criterion(fake_outputs, fake_labels)
        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        optimizer_d.step()

        # ---- Train the generator ----
        generator.zero_grad()
        z = torch.randn(batch_size, latent_dim)
        fake_images = generator(z)
        # Label the fakes as real: the generator is rewarded when the
        # discriminator is fooled.
        fake_labels = torch.ones(batch_size, 1)
        fake_outputs = discriminator(fake_images)
        g_loss = criterion(fake_outputs, fake_labels)
        g_loss.backward()
        optimizer_g.step()

    # Report the losses of the last batch of this epoch
    print('[%d/%d] d_loss: %.3f, g_loss: %.3f' % (epoch+1, num_epochs, d_loss.item(), g_loss.item()))

    # Generate a 4x4 grid of sample images for this epoch
    with torch.no_grad():
        z = torch.randn(16, latent_dim)
        samples = generator(z).numpy()
    fig, axes = plt.subplots(4, 4, figsize=(4, 4))
    for i, ax in enumerate(axes.flat):
        ax.imshow(samples[i].reshape(28, 28), cmap='gray')
        ax.axis('off')
    plt.savefig('samples/sample_%03d.png' % epoch)
    plt.close(fig)  # release the figure so memory does not grow per epoch
```
在这个示例中,我们定义了一个简单的生成器和判别器模型,使用 BCELoss 作为损失函数,并使用 Adam 优化器进行优化。示例假设 `train_data` 是已预先加载好的 MNIST 图像张量(每张图像展平为 784 维);数据加载代码并未包含在上面的代码中,运行前需要自行准备。每个 epoch 结束时,会用生成器生成 16 张样本图片,最终的结果将保存在 `samples` 文件夹中。
阅读全文