使用 GAN 生成 MNIST 手写数字
时间: 2023-11-14 13:09:12 浏览: 90
用 GAN（Generative Adversarial Networks，生成对抗网络）生成 MNIST 风格的手写数字图像，是深度学习中的一个经典入门问题。GAN 是一种无监督学习方法，由两个神经网络组成：生成器和判别器。生成器负责从随机噪声生成假的图像，判别器负责判断输入图像是真实的还是生成的。两个网络在训练中相互对抗，随着训练的进行，生成器生成的图像会越来越逼真。
以下是使用 PyTorch 实现 GAN、生成 MNIST 手写数字图像的示例代码：
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
# 定义生成器
class Generator(nn.Module):
    """Map a batch of 100-dim noise vectors to flattened 28x28 images.

    Three fully connected layers (100 -> 256 -> 512 -> 784) with ReLU between
    them and Tanh on the output, so every generated pixel lies in [-1, 1].
    """

    def __init__(self):
        super().__init__()
        # Layers are registered in the order fc1, fc2, fc3 so parameter names
        # (and the order in which they are randomly initialized) stay fixed.
        self.fc1 = nn.Linear(100, 256)
        self.fc2 = nn.Linear(256, 512)
        self.fc3 = nn.Linear(512, 784)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def forward(self, x):
        """Return a (batch, 784) tensor of generated pixels for noise ``x``."""
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = self.relu(layer(hidden))
        return self.tanh(self.fc3(hidden))
# 定义判别器
class Discriminator(nn.Module):
    """Score flattened 28x28 images with a probability in [0, 1].

    Fully connected layers 784 -> 512 -> 256 -> 1 with ReLU between them and
    a Sigmoid on the output, which is what ``nn.BCELoss`` expects as input.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return a (batch, 1) tensor of scores for flattened images ``x``."""
        features = self.relu(self.fc2(self.relu(self.fc1(x))))
        return self.sigmoid(self.fc3(features))
# 定义训练函数
def train(generator, discriminator, dataloader, optimizer_g, optimizer_d,
          criterion, *, num_epochs=100, latent_dim=100, device=None):
    """Train a GAN with alternating discriminator / generator updates.

    For every batch the discriminator is first updated to score real images
    as 1 and generated images as 0; then the generator is updated so that the
    discriminator scores freshly generated images as 1 (the non-saturating
    generator loss). The losses of the last batch are printed once per epoch.

    Args:
        generator: Module mapping ``(batch, latent_dim)`` noise to flattened
            images.
        discriminator: Module mapping flattened images to ``(batch, 1)``
            scores that ``criterion`` can compare with 0/1 targets.
        dataloader: Iterable of ``(images, labels)`` batches; labels are
            ignored and images are flattened per sample.
        optimizer_g: Optimizer over the generator's parameters.
        optimizer_d: Optimizer over the discriminator's parameters.
        criterion: Loss between discriminator outputs and 0/1 targets
            (``nn.BCELoss`` in this script).
        num_epochs: Number of full passes over ``dataloader`` (default 100).
        latent_dim: Length of each noise vector fed to the generator
            (default 100).
        device: Device for real images, labels and noise. Defaults to the
            device of the generator's first parameter (CPU if it has none),
            so no module-level ``device`` variable is required.

    Raises:
        ValueError: If ``dataloader`` yields no batches in an epoch.
    """
    if device is None:
        first_param = next(generator.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    for epoch in range(num_epochs):
        d_loss = g_loss = None
        for real_images, _ in dataloader:
            batch_size = real_images.size(0)
            real_images = real_images.view(batch_size, -1).to(device)
            # Targets are created directly on the device (no host->device copy).
            real_labels = torch.ones(batch_size, 1, device=device)
            fake_labels = torch.zeros(batch_size, 1, device=device)

            # Discriminator step.
            optimizer_d.zero_grad()
            real_loss = criterion(discriminator(real_images), real_labels)
            noise = torch.randn(batch_size, latent_dim, device=device)
            fake_images = generator(noise)
            # detach(): this step must not propagate gradients into the generator.
            fake_loss = criterion(discriminator(fake_images.detach()), fake_labels)
            d_loss = real_loss + fake_loss
            d_loss.backward()
            optimizer_d.step()

            # Generator step: target "real" labels, so the generator is
            # rewarded when the discriminator is fooled.
            optimizer_g.zero_grad()
            noise = torch.randn(batch_size, latent_dim, device=device)
            g_loss = criterion(discriminator(generator(noise)), real_labels)
            g_loss.backward()
            optimizer_g.step()

        if d_loss is None:
            raise ValueError('dataloader yielded no batches')
        print('Epoch [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}'.format(
            epoch + 1, num_epochs, d_loss.item(), g_loss.item()))
# Data: MNIST tensors from ToTensor(), then shifted and scaled by
# Normalize(mean=0.5, std=0.5) so they match the Tanh output range of the
# generator.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=[0.5], std=[0.5])]
)
train_dataset = datasets.MNIST(
    root='./data', train=True, transform=transform, download=True
)
train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True)

# Hyperparameters: compute device, plus the Adam learning rate and betas
# shared by both optimizers.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
lr, betas = 0.0002, (0.5, 0.999)
criterion = nn.BCELoss()

# Build both networks on the chosen device (generator first, then
# discriminator), then one Adam optimizer per network.
generator, discriminator = Generator().to(device), Discriminator().to(device)
optimizer_g, optimizer_d = (
    optim.Adam(model.parameters(), lr=lr, betas=betas)
    for model in (generator, discriminator)
)

# Run the training loop.
train(generator, discriminator, train_dataloader, optimizer_g, optimizer_d, criterion)
```
训练完成后，可以使用生成器生成假的 MNIST 图像。以下示例代码依赖上面训练脚本中定义的 `generator` 和 `device`，因此需要在同一个 Python 会话中、训练结束后运行：
```python
import matplotlib.pyplot as plt
# Generate 64 images from random noise. No gradients are needed for
# sampling, so run under no_grad() instead of building an autograd graph.
with torch.no_grad():
    noise = torch.randn(64, 100, device=device)
    fake_images = generator(noise).cpu()

# Show the samples in an 8x8 grid. vmin/vmax fix the gray scale to the
# generator's Tanh range [-1, 1]. Without them, imshow stretches each tile
# to its own min/max, so brightness is not comparable across samples.
fig, ax = plt.subplots(nrows=8, ncols=8, figsize=(10, 10))
for axis, image in zip(ax.flat, fake_images):
    axis.imshow(image.view(28, 28), cmap='gray', vmin=-1, vmax=1)
    axis.axis('off')
plt.show()
```
阅读全文