ACGAN生成对抗网络训练Pytorch代码 生成指定数字手写数字图片
时间: 2023-05-11 18:02:33 浏览: 192
以下是一个简单的 ACGAN 生成对抗网络的 PyTorch 代码,用于生成指定数字的手写数字图片:
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.autograd import Variable
import numpy as np
# Generator network
class Generator(nn.Module):
    """Fully connected conditional generator for 28x28 single-channel images.

    The noise vector is concatenated with a one-hot class encoding and passed
    through four linear layers; the final ``tanh`` keeps outputs in [-1, 1].

    Args:
        z_dim: Length of the noise vector.
        num_classes: Number of distinct class labels.
    """

    def __init__(self, z_dim=100, num_classes=10):
        super(Generator, self).__init__()
        self.z_dim = z_dim
        self.num_classes = num_classes
        self.fc1 = nn.Linear(z_dim + num_classes, 256)
        self.fc2 = nn.Linear(256, 512)
        self.fc3 = nn.Linear(512, 1024)
        self.fc4 = nn.Linear(1024, 28 * 28)
        self.relu = nn.ReLU()
        self.tanh = nn.Tanh()

    def forward(self, z, labels):
        """Generate one image per row of ``z``.

        Args:
            z: Noise tensor of shape ``(batch, z_dim)``.
            labels: Either a 1-D integer tensor of class indices, or a 2-D
                ``(batch, num_classes)`` one-hot tensor.

        Returns:
            Tensor of shape ``(batch, 1, 28, 28)``.
        """
        # fc1 expects z_dim + num_classes features, so raw class indices must
        # be expanded to one-hot rows before they can be concatenated with z.
        if labels.dim() == 1:
            labels = nn.functional.one_hot(labels.long(), self.num_classes)
        labels = labels.to(dtype=z.dtype, device=z.device)

        inputs = torch.cat([z, labels], dim=1)
        x = self.relu(self.fc1(inputs))
        x = self.relu(self.fc2(x))
        x = self.relu(self.fc3(x))
        x = self.tanh(self.fc4(x))
        return x.view(-1, 1, 28, 28)
# Discriminator network
class Discriminator(nn.Module):
    """Convolutional conditional discriminator for 28x28 single-channel images.

    Two stride-2 convolutions reduce the image to 128 feature maps of 7x7.
    The flattened features are concatenated with a one-hot class encoding and
    mapped to a single sigmoid output (real/fake probability). There is no
    separate class-prediction head.

    Args:
        num_classes: Number of distinct class labels.
    """

    def __init__(self, num_classes=10):
        super(Discriminator, self).__init__()
        self.num_classes = num_classes
        self.conv1 = nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)
        self.fc1 = nn.Linear(128 * 7 * 7 + num_classes, 1024)
        self.fc2 = nn.Linear(1024, 1)
        self.leaky_relu = nn.LeakyReLU(0.2)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x, labels):
        """Score each image as real or fake, conditioned on its label.

        Args:
            x: Image tensor of shape ``(batch, 1, 28, 28)``.
            labels: Either a 1-D integer tensor of class indices, or a 2-D
                ``(batch, num_classes)`` one-hot tensor.

        Returns:
            Tensor of shape ``(batch, 1)`` with values in [0, 1].
        """
        x = self.leaky_relu(self.conv1(x))
        x = self.leaky_relu(self.conv2(x))
        x = x.view(-1, 128 * 7 * 7)

        # fc1 expects 128*7*7 + num_classes features, so raw class indices
        # must be expanded to one-hot rows before concatenation.
        if labels.dim() == 1:
            labels = nn.functional.one_hot(labels.long(), self.num_classes)
        labels = labels.to(dtype=x.dtype, device=x.device)

        inputs = torch.cat([x, labels], dim=1)
        x = self.leaky_relu(self.fc1(inputs))
        x = self.sigmoid(self.fc2(x))
        return x
# Training function
def train(generator, discriminator, dataloader, num_epochs=200, z_dim=100,
          num_classes=10, lr=0.0002, beta1=0.5, beta2=0.999):
    """Alternately train a conditional discriminator and generator with BCE loss.

    Args:
        generator: Module called as ``generator(z, labels)`` returning images.
        discriminator: Module called as ``discriminator(images, labels)``
            returning real/fake probabilities.
        dataloader: Sized iterable yielding ``(images, labels)`` batches,
            where ``labels`` holds integer class indices.
        num_epochs: Number of passes over ``dataloader``.
        z_dim: Length of the noise vector fed to the generator.
        num_classes: Number of distinct class labels.
        lr: Adam learning rate for both optimizers.
        beta1: Adam first-moment coefficient.
        beta2: Adam second-moment coefficient.

    Returns:
        None. Both models are updated in place; progress is printed every
        100 steps.
    """
    criterion = nn.BCELoss()
    g_optimizer = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, beta2))
    d_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, beta2))

    def one_hot(class_indices):
        # Both networks concatenate labels with float features, so class
        # indices are expanded to float one-hot rows first.
        return nn.functional.one_hot(class_indices.long(), num_classes).float()

    def random_classes(count):
        return torch.LongTensor(np.random.randint(0, num_classes, count))

    for epoch in range(num_epochs):
        for i, (images, labels) in enumerate(dataloader):
            batch_size = images.size(0)

            # ---- Discriminator step ----
            d_optimizer.zero_grad()

            # Real images should be scored 1. Targets are built with
            # ones_like/zeros_like so their shape always matches the output.
            real_outputs = discriminator(images, one_hot(labels))
            d_loss_real = criterion(real_outputs, torch.ones_like(real_outputs))

            # Generated images should be scored 0. The random class labels
            # live in their own variable so they never replace the BCE target.
            z = torch.randn(batch_size, z_dim)
            fake_classes = one_hot(random_classes(batch_size))
            fake_images = generator(z, fake_classes)
            # detach(): this step must not backpropagate into the generator.
            fake_outputs = discriminator(fake_images.detach(), fake_classes)
            d_loss_fake = criterion(fake_outputs, torch.zeros_like(fake_outputs))

            d_loss = d_loss_real + d_loss_fake
            d_loss.backward()
            d_optimizer.step()

            # ---- Generator step ----
            g_optimizer.zero_grad()
            z = torch.randn(batch_size, z_dim)
            fake_classes = one_hot(random_classes(batch_size))
            fake_images = generator(z, fake_classes)
            fake_outputs = discriminator(fake_images, fake_classes)
            # The generator is rewarded when its images are scored as real.
            g_loss = criterion(fake_outputs, torch.ones_like(fake_outputs))
            g_loss.backward()
            g_optimizer.step()

            if (i + 1) % 100 == 0:
                # .item() extracts the Python float from a 0-dim loss tensor.
                print('Epoch [%d/%d], Step [%d/%d], d_loss: %.4f, g_loss: %.4f'
                      % (epoch + 1, num_epochs, i + 1, len(dataloader),
                         d_loss.item(), g_loss.item()))
# Load the MNIST training split; ToTensor followed by Normalize(0.5, 0.5)
# rescales pixels to [-1, 1], the same range as the generator's tanh output.
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=(0.5,), std=(0.5,))]
)
train_dataset = datasets.MNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=128,
    shuffle=True,
)

# Build both networks with their default sizes and start training
generator = Generator()
discriminator = Discriminator()
train(generator, discriminator, train_dataloader)
```
这个代码可以生成指定数字的手写数字图片。训练完成后,你可以向生成器传入对应的标签来指定要生成的数字,例如:
```python
# One noise vector per digit. The generator concatenates labels with the
# noise along dim=1, so labels are given as one-hot rows: row i of the
# identity matrix selects digit i.
z = torch.randn(10, 100)
labels = torch.eye(10)
images = generator(z, labels)
```
这将生成 0 到 9 的 10 个手写数字图片。
阅读全文