GAN代码python
时间: 2023-11-16 10:01:32 浏览: 37
这里提供了一个简单的 GAN 代码实现，包括生成器和判别器的定义、损失函数与优化器的定义以及 GAN 的训练过程。其中生成器和判别器都使用了多层感知机（全连接网络），输入输出是展平后的 28×28 图像（例如 MNIST）。你可以根据自己的需求进行修改和扩展。
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torchvision import datasets, transforms
# Generator: an MLP that maps a latent noise vector to a flattened image.
class Generator(nn.Module):
    """Four-layer MLP generator: input_size -> 256 -> 512 -> 1024 -> output_size.

    Hidden layers use LeakyReLU with negative slope 0.2. The output layer uses
    tanh, so every generated value lies in [-1, 1].

    Args:
        input_size: length of the noise vector (last dimension of the input).
        output_size: number of output features (784 = 28 * 28 by default).
    """

    def __init__(self, input_size=100, output_size=784):
        super().__init__()
        self.fc1 = nn.Linear(input_size, 256)
        self.fc2 = nn.Linear(256, 512)
        self.fc3 = nn.Linear(512, 1024)
        self.fc4 = nn.Linear(1024, output_size)

    def forward(self, x):
        """Map noise ``x`` (last dim ``input_size``) to outputs (last dim ``output_size``)."""
        hidden = x
        for layer in (self.fc1, self.fc2, self.fc3):
            hidden = F.leaky_relu(layer(hidden), negative_slope=0.2)
        return self.fc4(hidden).tanh()
# Discriminator: an MLP that scores a flattened image with a sigmoid probability.
class Discriminator(nn.Module):
    """Three-layer MLP discriminator: input_size -> 512 -> 256 -> output_size.

    Hidden layers use LeakyReLU (negative slope 0.2) followed by dropout
    (p = 0.3, active only in training mode). The output is passed through a
    sigmoid, so scores lie in (0, 1), as ``nn.BCELoss`` expects.

    Args:
        input_size: number of input features (784 = 28 * 28 by default).
        output_size: number of sigmoid output units.
    """

    def __init__(self, input_size=784, output_size=1):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(input_size, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, output_size)

    def forward(self, x):
        """Return sigmoid scores (last dim ``output_size``) for ``x`` (last dim ``input_size``)."""
        x = F.leaky_relu(self.fc1(x), 0.2)
        # F.dropout defaults to training=True. Passing the module's mode makes
        # .eval() actually disable dropout, so inference is deterministic.
        x = F.dropout(x, 0.3, training=self.training)
        x = F.leaky_relu(self.fc2(x), 0.2)
        x = F.dropout(x, 0.3, training=self.training)
        return F.sigmoid(self.fc3(x))
# Training: BCE loss, Adam optimizers, alternating discriminator / generator updates.
def train_gan(generator, discriminator, data_loader, num_epochs, *,
              latent_size=100, lr=0.0002, betas=(0.5, 0.999), device=None):
    """Train ``generator`` and ``discriminator`` adversarially.

    For every batch, the discriminator is first updated on real vs. generated
    images. The generator is then updated so that the discriminator scores
    freshly generated images as real.

    Args:
        generator: module mapping ``(batch, latent_size)`` noise to
            ``(batch, features)`` outputs.
        discriminator: module mapping ``(batch, features)`` to ``(batch, 1)``
            probabilities (``nn.BCELoss`` expects values in [0, 1]).
        data_loader: iterable of ``(images, labels)`` batches. Labels are
            ignored, and each image is flattened to ``features`` values.
        num_epochs: number of passes over ``data_loader``.
        latent_size: length of the noise vector fed to the generator.
        lr: Adam learning rate for both optimizers.
        betas: Adam beta coefficients for both optimizers.
        device: device to train on. Defaults to CUDA when available, else CPU.

    Returns:
        A list of ``(d_loss, g_loss)`` floats, one entry per batch.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Move the models themselves to the device, not only their inputs.
    generator.to(device).train()
    discriminator.to(device).train()

    criterion = nn.BCELoss()
    optimizer_G = optim.Adam(generator.parameters(), lr=lr, betas=betas)
    optimizer_D = optim.Adam(discriminator.parameters(), lr=lr, betas=betas)

    history = []
    for _ in range(num_epochs):
        for images, _labels in data_loader:
            batch_size = images.size(0)
            real_images = images.reshape(batch_size, -1).to(device)
            # Targets are shaped (batch, 1) to match the discriminator output.
            # BCELoss raises on a (batch,) target against a (batch, 1) input.
            real_labels = torch.ones(batch_size, 1, device=device)
            fake_labels = torch.zeros(batch_size, 1, device=device)

            # Discriminator step. detach() keeps this backward pass out of the
            # generator's graph, so only the discriminator receives gradients.
            optimizer_D.zero_grad()
            noise = torch.randn(batch_size, latent_size, device=device)
            fake_images = generator(noise).detach()
            d_loss_real = criterion(discriminator(real_images), real_labels)
            d_loss_fake = criterion(discriminator(fake_images), fake_labels)
            d_loss = d_loss_real + d_loss_fake
            d_loss.backward()
            optimizer_D.step()

            # Generator step: fresh fakes are scored against "real" targets,
            # so the generator is rewarded for fooling the discriminator.
            optimizer_G.zero_grad()
            noise = torch.randn(batch_size, latent_size, device=device)
            g_loss = criterion(discriminator(generator(noise)), real_labels)
            g_loss.backward()
            optimizer_G.step()

            history.append((d_loss.item(), g_loss.item()))
    return history


if __name__ == "__main__":
    # Scale pixels to [-1, 1] to match the generator's tanh output range.
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    mnist = datasets.MNIST(root="./data", train=True, download=True,
                           transform=transform)
    # Illustrative settings; tune batch size and epoch count as needed.
    data_loader = torch.utils.data.DataLoader(mnist, batch_size=100, shuffle=True)
    num_epochs = 50
    generator = Generator()
    discriminator = Discriminator()
    train_gan(generator, discriminator, data_loader, num_epochs)
```