ACGAN自动生成动漫头像代码PyTorch
时间: 2024-06-09 12:04:32 浏览: 174
基于ACGAN-动漫头像自动生成系统
5星 · 资源好评率100%
以下是基于 PyTorch 实现的 ACGAN 自动生成动漫头像的代码,主要参考了 https://github.com/znxlwm/pytorch-MNIST-CelebA-GAN-DCGAN/blob/master/pytorch_AC_GAN.py:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.autograd import Variable
from torchvision.utils import save_image
import os
# 定义网络结构
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.label_emb = nn.Embedding(10, 10)
self.model = nn.Sequential(
nn.Linear(110, 256),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(256, 512),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(512, 1024),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(1024, 784),
nn.Tanh()
)
def forward(self, z, labels):
c = self.label_emb(labels)
x = torch.cat([z, c], 1)
out = self.model(x)
return out.view(out.size(0), 1, 28, 28)
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.label_emb = nn.Embedding(10, 10)
self.model = nn.Sequential(
nn.Linear(794, 512),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(512, 256),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(256, 1),
nn.Sigmoid()
)
def forward(self, x, labels):
c = self.label_emb(labels)
x = x.view(x.size(0), -1)
x = torch.cat([x, c], 1)
out = self.model(x)
return out
# 定义训练函数
def train(generator, discriminator, dataloader, num_epochs=200):
# 定义损失函数和优化器
criterion = nn.BCELoss()
optimizer_g = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
# 定义固定的噪声和标签,用于生成样本
fixed_noise = Variable(torch.randn(100, 100))
fixed_labels = Variable(torch.LongTensor([i for _ in range(10) for i in range(10)]))
# 开始训练
for epoch in range(num_epochs):
for i, (images, labels) in enumerate(dataloader):
batch_size = images.size(0)
# 定义真实样本和标签
real_images = Variable(images)
real_labels = Variable(labels)
# 定义假样本和标签
noise = Variable(torch.randn(batch_size, 100))
fake_labels = Variable(torch.LongTensor([torch.randint(0, 10, (1,)).item() for _ in range(batch_size)]))
fake_images = generator(noise, fake_labels)
# 训练判别器
optimizer_d.zero_grad()
# 计算判别器的损失函数
real_outputs = discriminator(real_images, real_labels)
real_loss = criterion(real_outputs, Variable(torch.ones(batch_size, 1)))
fake_outputs = discriminator(fake_images.detach(), fake_labels)
fake_loss = criterion(fake_outputs, Variable(torch.zeros(batch_size, 1)))
d_loss = real_loss + fake_loss
# 反向传播并更新判别器参数
d_loss.backward()
optimizer_d.step()
# 训练生成器
optimizer_g.zero_grad()
# 计算生成器的损失函数
noise = Variable(torch.randn(batch_size, 100))
fake_labels = Variable(torch.LongTensor([torch.randint(0, 10, (1,)).item() for _ in range(batch_size)]))
fake_images = generator(noise, fake_labels)
g_outputs = discriminator(fake_images, fake_labels)
g_loss = criterion(g_outputs, Variable(torch.ones(batch_size, 1)))
# 反向传播并更新生成器参数
g_loss.backward()
optimizer_g.step()
if (i+1) % 100 == 0:
print("Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}"
.format(epoch+1, num_epochs, i+1, len(dataloader), d_loss.item(), g_loss.item()))
# 保存生成器状态和生成的样本
if (epoch+1) % 10 == 0:
if not os.path.exists('./ACGAN_images'):
os.mkdir('./ACGAN_images')
fake_images = generator(fixed_noise, fixed_labels)
save_image(fake_images.data, './ACGAN_images/{}_{}.png'.format(epoch+1, d_loss.item()), nrow=10, normalize=True)
# 加载数据集
dataset = datasets.ImageFolder(root='./data', transform=transforms.Compose([
transforms.Resize(64),
transforms.CenterCrop(64),
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
]))
dataloader = torch.utils.data.DataLoader(dataset, batch_size=64, shuffle=True)
# 初始化生成器和判别器
generator = Generator()
discriminator = Discriminator()
# 训练模型
train(generator, discriminator, dataloader, num_epochs=200)
```
其中,Generator 和 Discriminator 分别定义了生成器和判别器的网络结构,train 函数定义了训练过程,包括优化器、损失函数以及参数更新等。在训练过程中,我们还保存了一些生成的样本,以便后续查看训练效果。
阅读全文