DCGAN-Based MNIST Handwritten Digit Generation Code
Below is a Python code example that trains a GAN to generate MNIST handwritten digits:
```python
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.optim as optim
# Define the generator network
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.fc1 = nn.Linear(100, 128)
        self.fc2 = nn.Linear(128, 256)
        self.fc3 = nn.Linear(256, 512)
        self.fc4 = nn.Linear(512, 784)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.relu(self.fc3(x))
        # Tanh keeps outputs in [-1, 1], matching the normalized MNIST images
        x = self.tanh(self.fc4(x))
        return x
# Define the discriminator network
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.fc1 = nn.Linear(784, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, 1)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        # Sigmoid outputs the probability that the input image is real
        x = self.sigmoid(self.fc3(x))
        return x
# Prepare the dataset
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5,), (0.5,))])  # scale pixel values to [-1, 1]
trainset = torchvision.datasets.MNIST(root='./data', train=True,
                                      download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=128,
                                          shuffle=True, num_workers=2)
# Initialize the generator and discriminator
G = Generator()
D = Discriminator()

# Define the loss function and optimizers
criterion = nn.BCELoss()
optimizer_G = optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))
# Train the model
for epoch in range(50):
    for i, data in enumerate(trainloader, 0):
        # Update the discriminator
        D.zero_grad()
        real_images = data[0].view(-1, 784)
        real_labels = torch.ones(real_images.size(0), 1)
        fake_labels = torch.zeros(real_images.size(0), 1)
        # Train the discriminator to recognize real images
        real_output = D(real_images)
        real_loss = criterion(real_output, real_labels)
        # Train the discriminator to recognize generated images
        noise = torch.randn(real_images.size(0), 100)
        fake_images = G(noise)
        fake_output = D(fake_images.detach())
        fake_loss = criterion(fake_output, fake_labels)
        # Update the discriminator's weights
        d_loss = real_loss + fake_loss
        d_loss.backward()
        optimizer_D.step()

        # Update the generator
        G.zero_grad()
        noise = torch.randn(real_images.size(0), 100)
        fake_images = G(noise)
        fake_output = D(fake_images)
        g_loss = criterion(fake_output, real_labels)
        g_loss.backward()
        optimizer_G.step()

    # Periodically save some images to check the generator's progress
    if epoch % 5 == 0:
        with torch.no_grad():
            noise = torch.randn(25, 100)
            generated_images = G(noise).view(25, 1, 28, 28)
        generated_images = (generated_images + 1) / 2  # map [-1, 1] back to [0, 1]
        torchvision.utils.save_image(generated_images,
                                     'generated_images_{}.png'.format(epoch))
print('Finished Training')
```
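Note that although the title mentions DCGAN, the generator and discriminator above are built entirely from fully connected layers, so the example is really a vanilla (MLP) GAN. As a rough sketch, DCGAN-style convolutional replacements for the two networks might look like the following; the layer widths and kernel sizes are illustrative assumptions rather than part of the original code, and the training loop would need to pass images as (N, 1, 28, 28) tensors and noise as (N, 100, 1, 1) instead of flattened vectors:
```python
import torch.nn as nn

# Illustrative DCGAN-style networks for 28x28 grayscale MNIST images.
# Channel counts and kernel choices here are assumptions for demonstration.
class ConvGenerator(nn.Module):
    def __init__(self, nz=100):
        super().__init__()
        self.net = nn.Sequential(
            # (N, nz, 1, 1) -> (N, 128, 7, 7)
            nn.ConvTranspose2d(nz, 128, kernel_size=7, stride=1, padding=0, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            # (N, 128, 7, 7) -> (N, 64, 14, 14)
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            # (N, 64, 14, 14) -> (N, 1, 28, 28)
            nn.ConvTranspose2d(64, 1, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh(),
        )

    def forward(self, z):
        return self.net(z)

class ConvDiscriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(
            # (N, 1, 28, 28) -> (N, 64, 14, 14)
            nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            # (N, 64, 14, 14) -> (N, 128, 7, 7)
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            # (N, 128, 7, 7) -> (N, 1, 1, 1)
            nn.Conv2d(128, 1, kernel_size=7, stride=1, padding=0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, x):
        return self.net(x).view(-1, 1)
```
Convolutions exploit the 2D structure of the digits, which is the main reason a DCGAN typically produces sharper samples than a fully connected GAN.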
During training, the generator keeps producing handwritten digit images while the discriminator keeps judging whether each image is a real MNIST image or a fake produced by the generator. By repeatedly updating the weights of both networks against each other, the generator gradually learns to produce realistic handwritten digits. In the end, the trained generator can be used on its own to generate new handwritten digit images.
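As a minimal sketch of that last step (sampling from the trained generator), the snippet below reuses G and the 100-dimensional noise vector from the code above; the batch size of 64, the grid layout, and the output filename are arbitrary choices for illustration.
```python
# Sample a grid of new handwritten digits from the trained generator.
G.eval()  # switch to evaluation mode for inference
with torch.no_grad():
    noise = torch.randn(64, 100)            # 64 random latent vectors
    samples = G(noise).view(64, 1, 28, 28)  # reshape flat outputs to image tensors
    samples = (samples + 1) / 2             # map [-1, 1] back to [0, 1]
torchvision.utils.save_image(samples, 'sampled_digits.png', nrow=8)
```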