ACGAN 生成对抗网络 PyTorch 训练代码:生成指定数字的手写数字图片
时间: 2023-05-10 19:55:05 浏览: 61
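以下是一个简单的 ACGAN 风格的生成对抗网络 PyTorch 训练代码,可以生成指定数字的手写数字图片。需要说明的是:严格意义上的 ACGAN 要求判别器额外带一个辅助分类头来预测图片类别;而下面代码中的判别器是把类别标签作为输入、只输出真/假概率,因此实际实现的是条件 GAN(cGAN)。两种方式都能按给定标签生成指定数字: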
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.autograd import Variable
import numpy as np
# 定义生成器和判别器的网络结构
class Generator(nn.Module):
    """Label-conditioned generator: (noise, class label) -> image.

    The one-hot encoded label is concatenated to the noise vector and
    projected by fully connected layers to a 128x7x7 feature map. That map is
    upsampled by two stride-2 transposed convolutions (7 -> 14 -> 28). The
    output ends in Tanh, so pixel values lie in [-1, 1].

    Args:
        nz: length of the noise vector ``z``.
        ngf: accepted for interface compatibility; not used by this network.
        nc: number of output image channels.
        num_classes: number of label classes used for one-hot conditioning.
    """

    def __init__(self, nz, ngf, nc, num_classes):
        super(Generator, self).__init__()
        self.num_classes = num_classes
        self.fc = nn.Sequential(
            nn.Linear(nz + num_classes, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(True),
            nn.Linear(1024, 7 * 7 * 128),
            nn.BatchNorm1d(7 * 7 * 128),
            nn.ReLU(True),
        )
        self.conv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, nc, 4, 2, 1),
            nn.Tanh(),
        )

    def forward(self, z, labels):
        """Generate one image per row of ``z``.

        Args:
            z: noise tensor of shape (N, nz).
            labels: N integer class ids in [0, num_classes). May be a tensor
                on any device, a NumPy array, or a Python list.

        Returns:
            Tensor of shape (N, nc, 28, 28) on ``z``'s device.
        """
        # Build the one-hot condition with torch directly on z's device.
        # Indexing a NumPy array with a CUDA label tensor would fail, and a
        # NumPy round-trip forces a host->device copy on every call.
        labels = torch.as_tensor(labels, dtype=torch.long, device=z.device)
        c = nn.functional.one_hot(labels, self.num_classes).float()
        out = self.fc(torch.cat([z, c], 1))
        out = out.view(out.size(0), 128, 7, 7)
        return self.conv(out)
class Discriminator(nn.Module):
    """Label-conditioned discriminator: (image, class label) -> P(real).

    Two stride-2 convolutions reduce a 28x28 input to a 128x7x7 feature map.
    The flattened features are concatenated with the one-hot label and
    classified by fully connected layers ending in Sigmoid. The Linear input
    size (128*7*7 + num_classes) ties the network to inputs that downsample
    to 7x7, e.g. 28x28 MNIST.

    Args:
        ndf: accepted for interface compatibility; not used by this network.
        nc: number of input image channels.
        num_classes: number of label classes used for one-hot conditioning.
    """

    def __init__(self, ndf, nc, num_classes):
        super(Discriminator, self).__init__()
        self.num_classes = num_classes
        self.conv = nn.Sequential(
            nn.Conv2d(nc, 64, 4, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.fc = nn.Sequential(
            nn.Linear(128 * 7 * 7 + num_classes, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 1),
            nn.Sigmoid(),
        )

    def forward(self, x, labels):
        """Score each (image, label) pair.

        Args:
            x: image batch of shape (N, nc, 28, 28).
            labels: N integer class ids in [0, num_classes). May be a tensor
                on any device, a NumPy array, or a Python list.

        Returns:
            Tensor of shape (N, 1) with probabilities in (0, 1).
        """
        out = self.conv(x)
        out = out.view(out.size(0), -1)
        # One-hot on the feature tensor's device. The NumPy-based encoding
        # breaks when labels is a CUDA tensor.
        labels = torch.as_tensor(labels, dtype=torch.long, device=out.device)
        c = nn.functional.one_hot(labels, self.num_classes).float()
        return self.fc(torch.cat([out, c], 1))
# 定义训练函数
def train(dataloader, discriminator, generator, optimizer_d, optimizer_g, criterion, num_epochs, nz, num_classes):
    """Alternate discriminator and generator updates for ``num_epochs`` epochs.

    Each batch does two updates:
      * D step: real (image, label) pairs are pushed towards target 1.
        Generator samples with random labels are pushed towards target 0.
      * G step: fresh generator samples are pushed towards target 1, so that
        D is fooled.
    Progress is printed every 100 batches.

    Args:
        dataloader: iterable of (images, labels) batches that supports len().
        discriminator: module called as ``discriminator(images, labels)``
            that returns probabilities (shape (N, 1) or (N,)).
        generator: module called as ``generator(z, labels)``.
        optimizer_d: optimizer over the discriminator's parameters.
        optimizer_g: optimizer over the generator's parameters.
        criterion: probability-vs-target loss, e.g. ``nn.BCELoss()``.
        num_epochs: number of passes over ``dataloader``.
        nz: noise vector length fed to the generator.
        num_classes: labels for generated samples are drawn from
            [0, num_classes).
    """
    # Run on whatever device the models live on. Data and freshly created
    # tensors start on the CPU, so device checks such as images.is_cuda
    # would never move anything to the GPU.
    device = next(discriminator.parameters()).device
    for epoch in range(num_epochs):
        for i, (images, labels) in enumerate(dataloader):
            batch_size = images.size(0)
            images = images.to(device)
            labels = labels.to(device)
            real_targets = torch.ones(batch_size, device=device)
            fake_targets = torch.zeros(batch_size, device=device)

            # ---- Discriminator step ----
            optimizer_d.zero_grad()
            # Flatten D's output to (N,) so its shape matches the targets;
            # BCELoss rejects (N, 1) vs (N,).
            real_loss = criterion(discriminator(images, labels).view(-1), real_targets)
            z = torch.randn(batch_size, nz, device=device)
            fake_labels_ = torch.randint(0, num_classes, (batch_size,), device=device)
            fake_images = generator(z, fake_labels_)
            # detach(): the D loss must not backpropagate into the generator.
            fake_loss = criterion(discriminator(fake_images.detach(), fake_labels_).view(-1), fake_targets)
            d_loss = real_loss + fake_loss
            d_loss.backward()
            optimizer_d.step()

            # ---- Generator step ----
            optimizer_g.zero_grad()
            z = torch.randn(batch_size, nz, device=device)
            fake_labels_ = torch.randint(0, num_classes, (batch_size,), device=device)
            fake_images = generator(z, fake_labels_)
            g_loss = criterion(discriminator(fake_images, fake_labels_).view(-1), real_targets)
            g_loss.backward()
            optimizer_g.step()

            # Print progress every 100 batches; .item() extracts the
            # Python float (indexing a 0-dim tensor with [0] raises).
            if (i + 1) % 100 == 0:
                print('Epoch [%d/%d], Step [%d/%d], d_loss: %.4f, g_loss: %.4f'
                      % (epoch + 1, num_epochs, i + 1, len(dataloader), d_loss.item(), g_loss.item()))
# ---- Hyperparameters ----
batch_size, num_epochs = 128, 200
nz, ngf, ndf = 100, 64, 64          # noise length; ngf/ndf are passed to the models
nc, num_classes = 1, 10             # grayscale images, digits 0-9
lr, beta1 = 0.0002, 0.5             # Adam settings shared by both optimizers

# ---- MNIST data: ToTensor gives [0, 1]; Normalize(0.5, 0.5) maps it to
# [-1, 1], the same range as the generator's Tanh output ----
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(mean=(0.5,), std=(0.5,))]
)
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

# ---- Models (moved to the GPU when one is available) ----
generator = Generator(nz, ngf, nc, num_classes)
discriminator = Discriminator(ndf, nc, num_classes)
if torch.cuda.is_available():
    generator.cuda()
    discriminator.cuda()

# ---- Loss and optimizers ----
criterion = nn.BCELoss()
optimizer_g = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
optimizer_d = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))

# ---- Run training ----
train(train_dataloader, discriminator, generator, optimizer_d, optimizer_g,
      criterion, num_epochs, nz, num_classes)
```
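训练完成后,要生成指定数字的手写数字图片,只需把目标数字作为标签传给生成器即可。注意:训练函数中的 `fake_labels_` 是训练过程中随机采样的类别,修改它只会影响训练,并不会输出图片。例如,生成 16 张数字 7:

```python
generator.eval()
with torch.no_grad():
    z = torch.randn(16, nz, device=next(generator.parameters()).device)
    images = generator(z, [7] * 16)  # 形状 (16, 1, 28, 28),像素范围 [-1, 1]
```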
相关推荐
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![zip](https://img-home.csdnimg.cn/images/20210720083736.png)
![ipynb](https://img-home.csdnimg.cn/images/20210720083646.png)
![-](https://csdnimg.cn/download_wenku/file_type_column_c1.png)
![-](https://csdnimg.cn/download_wenku/file_type_column_c1.png)
![-](https://csdnimg.cn/download_wenku/file_type_lunwen.png)
![-](https://csdnimg.cn/download_wenku/file_type_column_c1.png)
![-](https://csdnimg.cn/download_wenku/file_type_column_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)
![](https://csdnimg.cn/download_wenku/file_type_ask_c1.png)