GAN 代码 pytorch实现
时间: 2023-07-13 18:07:25 浏览: 119
以下是一个简单的 GAN 在 PyTorch 中的实现:
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
# 定义生成器
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.fc = nn.Sequential(
nn.Linear(100, 128),
nn.ReLU(),
nn.Linear(128, 784),
nn.Tanh()
)
def forward(self, x):
x = self.fc(x)
return x.view(-1, 1, 28, 28)
# 定义鉴别器
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.fc = nn.Sequential(
nn.Linear(784, 128),
nn.ReLU(),
nn.Linear(128, 1),
nn.Sigmoid()
)
def forward(self, x):
x = x.view(-1, 784)
x = self.fc(x)
return x
# 定义超参数
batch_size = 128
lr = 0.0002
z_dim = 100
num_epochs = 100
# 加载数据集
from torchvision import datasets, transforms
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
train_dataset = datasets.MNIST(root='./data/', train=True, transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# 定义生成器和鉴别器
generator = Generator()
discriminator = Discriminator()
# 定义优化器和损失函数
optimizer_g = optim.Adam(generator.parameters(), lr=lr)
optimizer_d = optim.Adam(discriminator.parameters(), lr=lr)
criterion = nn.BCELoss()
# 开始训练
for epoch in range(num_epochs):
for i, (imgs, _) in enumerate(train_loader):
batch_size = imgs.size(0)
real_imgs = Variable(imgs.cuda())
z = Variable(torch.randn(batch_size, z_dim).cuda())
# 训练鉴别器
optimizer_d.zero_grad()
real_loss = criterion(discriminator(real_imgs), Variable(torch.ones(batch_size, 1).cuda()))
fake_imgs = generator(z)
fake_loss = criterion(discriminator(fake_imgs.detach()), Variable(torch.zeros(batch_size, 1).cuda()))
d_loss = real_loss + fake_loss
d_loss.backward()
optimizer_d.step()
# 训练生成器
optimizer_g.zero_grad()
g_loss = criterion(discriminator(fake_imgs), Variable(torch.ones(batch_size, 1).cuda()))
g_loss.backward()
optimizer_g.step()
# 打印损失
if (i + 1) % 100 == 0:
print('Epoch [%d/%d], Step [%d/%d], d_loss: %.4f, g_loss: %.4f' %
(epoch + 1, num_epochs, i + 1, len(train_loader), d_loss.data[0], g_loss.data[0]))
# 保存模型
torch.save(generator.state_dict(), 'generator.pkl')
torch.save(discriminator.state_dict(), 'discriminator.pkl')
```
这是一个简单的 GAN,它用于生成手写数字图像。其中,生成器和鉴别器都是简单的全连接神经网络,损失函数使用二元交叉熵。训练过程中,先训练鉴别器,再训练生成器。在训练完成后,保存生成器和鉴别器的模型参数。
阅读全文