mnist手写数字生成pytorch
时间: 2023-10-10 21:06:52 浏览: 43
以下是一个基于GAN的MNIST手写数字生成的PyTorch代码示例:
```python
import torch
import torch.nn as nn
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
# Generator network
class Generator(nn.Module):
    """Fully connected generator: maps an input vector to an output vector.

    The network is three linear layers with ReLU after the first two.

    Args:
        input_size: Number of features in each input row.
        hidden_size: Width of both hidden layers.
        output_size: Number of features in each output row.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Run ``x`` through the three layers and return the result.

        Args:
            x: Tensor whose last dimension equals ``input_size``.

        Returns:
            Tensor whose last dimension equals ``output_size``.
        """
        hidden = self.relu(self.fc1(x))
        hidden = self.relu(self.fc2(hidden))
        # No activation on the final layer, so the output is unbounded.
        return self.fc3(hidden)
# Discriminator network
class Discriminator(nn.Module):
    """Fully connected discriminator producing one score per input row.

    The network is three linear layers, each followed by a sigmoid, so the
    final score lies strictly between 0 and 1.

    Args:
        input_size: Number of features in each input row.
        hidden_size: Width of both hidden layers.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Score each row of ``x``.

        Args:
            x: Tensor whose last dimension equals ``input_size``.

        Returns:
            Tensor with a last dimension of size 1 holding values in (0, 1).
        """
        hidden = self.sigmoid(self.fc1(x))
        hidden = self.sigmoid(self.fc2(hidden))
        # The final sigmoid keeps the score in (0, 1).
        return self.sigmoid(self.fc3(hidden))
# Hyperparameters
input_size = 100    # length of the random noise vector fed to the generator
hidden_size = 256   # width of the hidden layers in both networks
output_size = 784   # length of one generated row; also the discriminator's input size
batch_size = 128
num_epochs = 200

# Load the MNIST training split (downloaded to ./data on first use)
train_dataset = MNIST(root='./data', train=True, transform=ToTensor(), download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# Build the two networks
G = Generator(input_size, hidden_size, output_size)
D = Discriminator(output_size, hidden_size)

# Loss function and optimizers
criterion = nn.BCELoss()
lr = 0.0002
G_optimizer = torch.optim.Adam(G.parameters(), lr=lr)
D_optimizer = torch.optim.Adam(D.parameters(), lr=lr)

# Target labels for a full batch: 1 for real, 0 for fake.
# They are sliced per step below because the last batch may be shorter.
real_label = torch.ones(batch_size, 1)
fake_label = torch.zeros(batch_size, 1)

# Training loop
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(train_loader):
        # The loader keeps the final, possibly smaller, batch of each epoch.
        # Size everything from the batch actually received; reshaping with the
        # nominal batch_size would give the wrong feature width and crash in D.
        current_batch = images.size(0)
        real_targets = real_label[:current_batch]
        fake_targets = fake_label[:current_batch]

        # Flatten the real images and generate an equal number of fakes
        real_images = images.view(current_batch, -1)
        z = torch.randn(current_batch, input_size)
        fake_images = G(z)

        # Train the discriminator. detach() stops this loss from
        # back-propagating into the generator.
        D_real_loss = criterion(D(real_images), real_targets)
        D_fake_loss = criterion(D(fake_images.detach()), fake_targets)
        D_loss = D_real_loss + D_fake_loss
        D_optimizer.zero_grad()
        D_loss.backward()
        D_optimizer.step()

        # Train the generator: it is rewarded when D labels its fakes as real
        G_loss = criterion(D(fake_images), real_targets)
        G_optimizer.zero_grad()
        G_loss.backward()
        G_optimizer.step()

        # Report losses every 100 steps
        if (i + 1) % 100 == 0:
            print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], '
                  f'D Loss: {D_loss.item():.4f}, G Loss: {G_loss.item():.4f}')

# Save the trained generator weights
torch.save(G.state_dict(), 'generator.pth')
```
在训练完成后,可以使用生成器来生成新的手写数字图像(以下代码需要与上面的训练代码在同一会话中运行,因为它沿用了其中的 torch、Generator 类以及 input_size、hidden_size、output_size 等定义),例如:
```python
import matplotlib.pyplot as plt
import numpy as np

# Rebuild a generator with the same layer sizes and restore the weights
# stored in 'generator.pth'.
G = Generator(input_size, hidden_size, output_size)
saved_weights = torch.load('generator.pth')
G.load_state_dict(saved_weights)

# Draw one random noise vector and turn the generator output into a
# 28x28 NumPy array.
z = torch.randn(1, input_size)
generated = G(z)
fake_image = np.reshape(generated.detach().numpy(), (28, 28))

# Display the array as a grayscale image.
plt.imshow(fake_image, cmap='gray')
plt.show()
```
这样就可以生成一个随机的手写数字图像了。