来一个简单的GAN代码实例
时间: 2023-12-04 08:40:56 浏览: 89
生成对抗网络(GAN)实例 代码+数据集
5星 · 资源好评率100%
以下是一个基于 PyTorch 的简单 GAN 代码实例,在 MNIST 数据集上训练,用于生成手写数字图片:
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import matplotlib.pyplot as plt
# Generator network
class Generator(nn.Module):
    """Fully connected generator that maps latent noise vectors to images.

    Args:
        latent_dim: Length of each input noise vector. Defaults to 100.
        img_shape: Shape of one generated sample, excluding the batch axis.
            Defaults to ``(28, 28)``.
    """

    def __init__(self, latent_dim=100, img_shape=(28, 28)):
        super().__init__()
        self.latent_dim = latent_dim
        self.img_shape = tuple(img_shape)
        # Number of values in one sample (784 for the default 28x28 shape).
        out_features = torch.Size(self.img_shape).numel()
        # The attribute stays named ``main`` so state-dict keys are unchanged.
        self.main = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, out_features),
            nn.Tanh(),  # bounds every output value to [-1, 1]
        )

    def forward(self, x):
        """Return generated samples of shape ``(-1, *img_shape)``.

        Args:
            x: Noise tensor whose last dimension is ``latent_dim``.
        """
        img = self.main(x)
        # Unflatten each row of the final linear layer into image form.
        return img.view(-1, *self.img_shape)
# Discriminator network
class Discriminator(nn.Module):
    """Fully connected discriminator that scores images as real or fake.

    Args:
        img_shape: Shape of one input sample, excluding the batch axis.
            Defaults to ``(28, 28)``, i.e. 784 values per sample.
    """

    def __init__(self, img_shape=(28, 28)):
        super().__init__()
        self.img_shape = tuple(img_shape)
        # Number of values in one sample (784 for the default 28x28 shape).
        self.in_features = torch.Size(self.img_shape).numel()
        # The attribute stays named ``main`` so state-dict keys are unchanged.
        self.main = nn.Sequential(
            nn.Linear(self.in_features, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
            nn.Sigmoid(),  # squashes the score into (0, 1)
        )

    def forward(self, x):
        """Return one validity score per sample, shape ``(-1, 1)``.

        Args:
            x: Image tensor holding a multiple of ``in_features`` values.
        """
        # ``reshape`` rather than ``view``: ``view`` raises on non-contiguous
        # input (e.g. a transposed batch), ``reshape`` copies when required.
        x = x.reshape(-1, self.in_features)
        validity = self.main(x)
        return validity
# Training loop
def train(generator, discriminator, dataloader, optimizer_G, optimizer_D,
          criterion, device, num_epochs=200, latent_dim=100,
          sample_interval=10, sample_dir="images"):
    """Alternately train the discriminator and the generator.

    For every batch the discriminator is updated first, on the real batch and
    on a detached batch of generated images. The generator is then updated on
    a fresh batch of generated images scored against the "real" label.

    Args:
        generator: Module mapping ``(batch, latent_dim)`` noise to images.
        discriminator: Module mapping images to ``(batch, 1)`` scores.
        dataloader: Iterable yielding ``(images, labels)`` pairs; the labels
            are ignored.
        optimizer_G: Optimizer over the generator's parameters.
        optimizer_D: Optimizer over the discriminator's parameters.
        criterion: Loss called as ``criterion(scores, targets)``.
        device: Device that batches, labels and noise are moved to.
        num_epochs: Number of passes over ``dataloader``. Defaults to 200.
        latent_dim: Length of each noise vector. Defaults to 100.
        sample_interval: A sample grid is written every this many epochs
            (must be positive). Defaults to 10.
        sample_dir: Directory that receives ``<epoch>.png`` sample grids; it
            is created when missing. Pass ``None`` to skip writing samples.
    """
    if sample_dir is not None:
        # Only needed when samples are written, so sampling-free runs do not
        # require these modules.
        import os
        from torchvision.utils import save_image

        os.makedirs(sample_dir, exist_ok=True)

    for epoch in range(num_epochs):
        # Stay ``None`` when the dataloader yields no batch, so the report
        # below is skipped instead of hitting unbound names.
        d_loss = g_loss = None

        for imgs, _ in dataloader:
            real_imgs = imgs.to(device)
            batch_size = real_imgs.size(0)
            real_labels = torch.ones(batch_size, 1, device=device)
            fake_labels = torch.zeros(batch_size, 1, device=device)

            # ---- discriminator step ----
            optimizer_D.zero_grad()
            z = torch.randn((batch_size, latent_dim)).to(device)
            fake_imgs = generator(z)
            real_loss = criterion(discriminator(real_imgs), real_labels)
            # ``detach`` keeps this backward pass out of the generator.
            fake_loss = criterion(discriminator(fake_imgs.detach()), fake_labels)
            d_loss = (real_loss + fake_loss) / 2
            d_loss.backward()
            optimizer_D.step()

            # ---- generator step ----
            optimizer_G.zero_grad()
            z = torch.randn((batch_size, latent_dim)).to(device)
            fake_imgs = generator(z)
            # The generator is rewarded when its fakes are scored as real.
            g_loss = criterion(discriminator(fake_imgs), real_labels)
            g_loss.backward()
            optimizer_G.step()

        # Report the losses of the last batch of this epoch.
        if d_loss is not None:
            print("[Epoch %d/%d] [D loss: %f] [G loss: %f]" % (epoch, num_epochs, d_loss.item(), g_loss.item()))

        # Periodically write a 5x5 grid of generated samples.
        if sample_dir is not None and epoch % sample_interval == 0:
            with torch.no_grad():
                z = torch.randn((25, latent_dim)).to(device)
                gen_imgs = generator(z).cpu()
            if gen_imgs.dim() == 3:
                # (N, H, W) -> (N, 1, H, W): without a channel axis the batch
                # would be read as one multi-channel image.
                gen_imgs = gen_imgs.unsqueeze(1)
            save_image(gen_imgs, os.path.join(sample_dir, "%d.png" % epoch), nrow=5, normalize=True)
# Load the MNIST training split. Each image is converted to a tensor and then
# normalized with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))]
)
mnist_dataset = datasets.MNIST(
    root="./data",
    train=True,
    transform=transform,
    download=True,
)
dataloader = DataLoader(mnist_dataset, batch_size=128, shuffle=True)

# Run on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build both networks on the chosen device, one Adam optimizer per network
# (identical hyper-parameters), and the binary cross-entropy criterion.
generator = Generator().to(device)
discriminator = Discriminator().to(device)
adam_options = {"lr": 0.0002, "betas": (0.5, 0.999)}
optimizer_G = optim.Adam(generator.parameters(), **adam_options)
optimizer_D = optim.Adam(discriminator.parameters(), **adam_options)
criterion = nn.BCELoss()

# Start training.
train(
    generator=generator,
    discriminator=discriminator,
    dataloader=dataloader,
    optimizer_G=optimizer_G,
    optimizer_D=optimizer_D,
    criterion=criterion,
    device=device,
)
```
阅读全文