ACGAN自动生成动漫头像代码PyTorch
时间: 2024-06-09 10:04:34 浏览: 195
以下是使用 PyTorch 实现 ACGAN 自动生成动漫头像的代码:
```python
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.utils
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
# Seed PyTorch's RNG so weight initialisation and noise sampling are repeatable.
torch.manual_seed(1)

# Hyperparameters.
batch_size = 64      # samples per training batch
num_epochs = 200     # full passes over the dataset
z_dimension = 100    # length of the generator's noise vector
num_classes = 10     # number of class labels the networks are conditioned on
image_size = 64      # side length (pixels) of the square training images

# Data pipeline.
transform = transforms.Compose([
    # Resize(int) only scales the shorter edge to image_size and keeps the
    # aspect ratio, so CenterCrop is required to guarantee square
    # image_size x image_size tensors that can be stacked into a batch.
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    # Map each of the 3 channels from [0, 1] to [-1, 1], matching the range of
    # the generator's Tanh output.
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])
# ImageFolder reads one sub-directory per class under 'data'.
train_dataset = datasets.ImageFolder('data', transform)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Generator network.
class Generator(nn.Module):
    """Conditional generator mapping a noise vector and a class label to an image.

    The label is embedded, concatenated with the noise, projected by a linear
    layer to a 128-channel feature map of side ``image_size // 4`` and then
    upsampled twice (2x each) by transposed convolutions. The final ``Tanh``
    keeps pixel values in [-1, 1].

    Args:
        channels: Number of channels in the generated image. Defaults to 3,
            the channel count of the 3-channel-normalised training images.
    """

    def __init__(self, channels=3):
        super(Generator, self).__init__()
        # The two stride-2 transposed convolutions upsample by 4x in total, so
        # the feature map has to start at image_size / 4 to end at image_size.
        self.init_size = image_size // 4
        self.label_emb = nn.Embedding(num_classes, num_classes)
        self.layer1 = nn.Sequential(
            nn.Linear(z_dimension + num_classes, 128 * self.init_size ** 2),
            nn.BatchNorm1d(128 * self.init_size ** 2),
            nn.ReLU(inplace=True),
        )
        self.layer2 = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
        )
        self.layer3 = nn.Sequential(
            nn.ConvTranspose2d(64, channels, 4, 2, padding=1),
            nn.Tanh(),
        )

    def forward(self, x, label):
        """Generate a batch of images.

        Args:
            x: Noise tensor of shape (batch, z_dimension).
            label: Integer class indices of shape (batch,).

        Returns:
            Tensor of shape (batch, channels, S, S) with S = 4 * (image_size // 4),
            i.e. ``image_size`` whenever it is a multiple of 4.
        """
        x = torch.cat([x, self.label_emb(label)], dim=1)
        x = self.layer1(x)
        # Reshape the flat projection into a (128, init_size, init_size) map.
        x = x.view(x.shape[0], 128, self.init_size, self.init_size)
        x = self.layer2(x)
        return self.layer3(x)
# Discriminator network.
class Discriminator(nn.Module):
    """Conditional discriminator scoring an (image, class label) pair.

    The label embedding is broadcast to one constant plane per embedding
    dimension and stacked onto the image channels. Two stride-2 convolutions
    reduce the spatial size by 4x, and a linear layer followed by ``Sigmoid``
    produces one probability per sample.

    Args:
        channels: Number of channels in the input image. Defaults to 3,
            the channel count of the 3-channel-normalised training images.
    """

    def __init__(self, channels=3):
        super(Discriminator, self).__init__()
        # Side length of the feature map after the two stride-2 convolutions.
        self.feature_size = image_size // 4
        self.label_emb = nn.Embedding(num_classes, num_classes)
        self.layer1 = nn.Sequential(
            nn.Conv2d(channels + num_classes, 64, 4, 2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(64, 128, 4, 2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
        )
        self.layer3 = nn.Sequential(
            nn.Linear(128 * self.feature_size ** 2, 1),
            nn.Sigmoid(),
        )

    def forward(self, x, label):
        """Score a batch of images.

        Args:
            x: Image tensor of shape (batch, channels, image_size, image_size).
            label: Integer class indices of shape (batch,).

        Returns:
            Tensor of shape (batch, 1) with values in (0, 1).
        """
        # (batch, num_classes) -> (batch, num_classes, 1, 1) -> (batch, num_classes, H, W).
        # torch.cat requires every non-concatenated dimension to match, so the
        # label planes must be expanded to the image's spatial size first.
        label_map = self.label_emb(label).unsqueeze(2).unsqueeze(3)
        label_map = label_map.expand(-1, -1, x.shape[2], x.shape[3])
        x = torch.cat([x, label_map], dim=1)
        x = self.layer1(x)
        x = self.layer2(x)
        x = x.view(x.shape[0], -1)
        return self.layer3(x)
# Discriminator loss.
def discriminator_loss(logits_real, logits_fake):
    """Return the binary cross-entropy between discriminator outputs and targets.

    The parameter names are kept for compatibility, but the training loop in
    this file calls the function as ``discriminator_loss(outputs, targets)``:
    once with the scores of real images and all-ones targets, and once with
    the scores of fake images and all-zeros targets.

    Args:
        logits_real: Discriminator outputs. These are probabilities in (0, 1)
            because the discriminator ends with a Sigmoid; any shape with one
            value per sample, e.g. (batch, 1).
        logits_fake: Target values (1 for real, 0 for fake) with one value
            per sample, e.g. shape (batch,).

    Returns:
        Scalar tensor: the mean binary cross-entropy over the batch.
    """
    # Flatten both so a (batch, 1) output lines up with a (batch,) target.
    predictions = logits_real.reshape(-1)
    targets = logits_fake.reshape(-1).to(predictions.dtype)
    return nn.functional.binary_cross_entropy(predictions, targets)
# Generator loss.
def generator_loss(logits_fake):
    """Return the generator's loss for a batch of discriminator scores.

    The generator is rewarded when the discriminator scores its images as
    real, so the loss is the binary cross-entropy between the scores and an
    all-ones target.

    Args:
        logits_fake: Discriminator outputs for generated images. These are
            probabilities in (0, 1) because the discriminator ends with a
            Sigmoid; any shape with one value per sample, e.g. (batch, 1).

    Returns:
        Scalar tensor: the mean binary cross-entropy over the batch.
    """
    predictions = logits_fake.reshape(-1)
    return nn.functional.binary_cross_entropy(
        predictions, torch.ones_like(predictions)
    )
# Instantiate the generator and the discriminator.
generator = Generator()
discriminator = Discriminator()

# One Adam optimizer per network (lr = 2e-4, beta1 = 0.5, beta2 = 0.999).
optimizer_g = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_d = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Fixed generator inputs reused for every sample grid: one noise vector per
# class, paired with the labels 0 .. num_classes - 1. Plain tensors are used
# because the Variable wrapper is deprecated and no longer needed.
fixed_z = torch.randn(num_classes, z_dimension)
fixed_label = torch.arange(num_classes)
# Training loop: for every batch, update the discriminator, then the generator.
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # Use a separate name so the batch_size hyperparameter is not
        # overwritten by the size of a (possibly smaller) final batch.
        current_batch = images.size(0)
        if current_batch < 2:
            # The generator's BatchNorm1d cannot compute batch statistics
            # from a single sample in training mode, so skip such a batch.
            continue

        # Targets for the discriminator: 1 = real, 0 = fake.
        real_labels = torch.ones(current_batch)
        fake_labels = torch.zeros(current_batch)

        # ---- Discriminator step ----
        discriminator.zero_grad()

        # Loss on real images.
        outputs = discriminator(images, labels)
        d_loss_real = discriminator_loss(outputs, real_labels)

        # Loss on generated images. detach() stops gradients from flowing
        # into the generator, which is not updated in this step.
        z = torch.randn(current_batch, z_dimension)
        fake_images = generator(z, labels)
        outputs = discriminator(fake_images.detach(), labels)
        d_loss_fake = discriminator_loss(outputs, fake_labels)

        d_loss = d_loss_real + d_loss_fake
        d_loss.backward()
        optimizer_d.step()

        # ---- Generator step ----
        generator.zero_grad()

        # Fresh noise, same class labels as the real batch.
        z = torch.randn(current_batch, z_dimension)
        fake_images = generator(z, labels)

        outputs = discriminator(fake_images, labels)
        g_loss = generator_loss(outputs)
        g_loss.backward()
        optimizer_g.step()

        if (i + 1) % 100 == 0:
            # .item() extracts the Python float from a 0-dim loss tensor;
            # indexing such a tensor with [0] raises an error.
            print("Epoch[{}/{}], d_loss: {:.6f}, g_loss: {:.6f}".format(
                epoch + 1, num_epochs, d_loss.item(), g_loss.item()
            ))

    # Save one sample grid per epoch from the fixed noise/labels. no_grad()
    # avoids building an autograd graph that is never used.
    with torch.no_grad():
        fake_images = generator(fixed_z, fixed_label)
    torchvision.utils.save_image(
        fake_images, './fake_images-{}.png'.format(epoch + 1), nrow=10, normalize=True
    )

# Save the trained weights once training has finished.
torch.save(generator.state_dict(), './generator.pth')
torch.save(discriminator.state_dict(), './discriminator.pth')
# Display generated images.
def show_images(images):
    """Show a batch of images side by side in a single row.

    Args:
        images: Tensor of shape (batch, channels, height, width). Values are
            mapped from [-1, 1] to [0, 1] for display and clipped to that
            range. Single-channel images are drawn in grayscale.
    """
    # (batch, C, H, W) -> (batch, H, W, C), the layout imshow expects.
    pixels = images.detach().cpu().numpy().transpose(0, 2, 3, 1)
    if len(pixels) == 0:
        # plt.subplots cannot create a grid with zero columns.
        return
    pixels = np.clip((pixels + 1) / 2, 0, 1)

    # squeeze=False keeps `axes` a 2-D array even when there is one image;
    # otherwise a bare Axes object is returned and indexing it fails.
    _, axes = plt.subplots(1, len(pixels), figsize=(10, 10), squeeze=False)
    for ax, image in zip(axes[0], pixels):
        if image.shape[-1] == 1:
            # imshow rejects (H, W, 1) arrays, so drop the channel axis.
            ax.imshow(image[..., 0], cmap='gray', vmin=0, vmax=1)
        else:
            ax.imshow(image)
        ax.axis('off')
    plt.show()
# Reload the saved generator weights, then render one sample per class from
# the fixed noise vectors and labels.
saved_state = torch.load('./generator.pth')
generator.load_state_dict(saved_state)
fake_images = generator(fixed_z, fixed_label)
show_images(fake_images)
```
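在运行代码之前,需要将数据集放在 `data` 目录下,并且每个类别的图片应放在各自单独的子文件夹内(`ImageFolder` 会把每个子文件夹当作一个类别,类别数量不能超过代码中的 `num_classes`)。可以在 Jupyter Notebook 中运行以下命令下载已经准备好的数据集(以 `!` 开头的是 shell 命令,不是 Python 语句):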
```python
!wget https://github.com/Hongyanlee0614/ACGAN-anime-face/raw/main/data.zip
!unzip data.zip
```
这段代码实现了按类别标签条件生成动漫头像的训练流程:生成器和判别器都以条件 GAN 的方式工作,即生成器在输入噪声的同时还输入类别标签,判别器在输入图片的同时也输入类别标签。训练时,每个批次先更新判别器,再更新生成器,两者不断交替进行,直到所有 epoch 训练完成。训练完成后,可以加载保存的模型,并使用生成器生成新的头像。
阅读全文