acgan自动生成动漫头像代码和数据集
时间: 2023-10-11 15:07:55 浏览: 194
ACGAN (Auxiliary Classifier GAN) 是一种生成对抗网络,可以用于生成具有特定属性的图像,如动漫头像。以下是一个基于PyTorch的ACGAN代码示例,以及一个动漫头像数据集。
首先,您需要准备一个动漫头像数据集。您可以从以下链接下载一个预处理好的动漫头像数据集:https://drive.google.com/file/d/1v9MlXr9TqXOQwqDyL0QzEgPq3zZvA6Cz/view?usp=sharing
解压缩数据集后,您将得到一个名为“faces”的文件夹。注意:下面的代码使用 `datasets.ImageFolder` 加载数据,它要求“faces”目录下按类别建立子文件夹(每个子文件夹名即为一个类别标签),而不是把所有头像直接放在同一层目录中;请先按此结构整理数据集。
接下来,您需要安装PyTorch和其他必要的库。
```python
!pip install torch torchvision
!pip install numpy matplotlib
```
然后,您可以使用以下代码训练ACGAN模型。
```python
import os

import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
# 定义生成器
class Generator(nn.Module):
    """Label-conditional DCGAN-style generator.

    The class label is embedded into a num_classes-dim vector and
    concatenated with the latent noise; three stride-2 transposed
    convolutions upsample 7 -> 14 -> 28 -> 56, so the output is a
    (B, img_channels, 56, 56) image squashed into [-1, 1] by Tanh.
    """

    def __init__(self, z_dim, num_classes, img_channels):
        super(Generator, self).__init__()
        self.z_dim = z_dim
        self.num_classes = num_classes
        self.img_channels = img_channels
        # One learnable num_classes-dim vector per label.
        self.label_emb = nn.Embedding(num_classes, num_classes)

        def up_block(c_in, c_out, kernel, stride, pad):
            # Transposed conv + BN + ReLU; conv bias is redundant under BatchNorm.
            return [
                nn.ConvTranspose2d(c_in, c_out, kernel, stride, pad, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            ]

        layers = []
        layers += up_block(self.z_dim + self.num_classes, 256, 7, 1, 0)  # 1x1 -> 7x7
        layers += up_block(256, 128, 4, 2, 1)                            # 7 -> 14
        layers += up_block(128, 64, 4, 2, 1)                             # 14 -> 28
        layers.append(nn.ConvTranspose2d(64, self.img_channels, 4, 2, 1, bias=False))  # 28 -> 56
        layers.append(nn.Tanh())  # match data normalized to [-1, 1]
        self.generator = nn.Sequential(*layers)

    def forward(self, noise, labels):
        """Generate images from (noise, labels); returns (B, C, 56, 56)."""
        conditioned = torch.cat((self.label_emb(labels), noise), -1)
        conditioned = conditioned.view(conditioned.size(0), conditioned.size(1), 1, 1)
        return self.generator(conditioned)
# 定义判别器
class Discriminator(nn.Module):
    """Label-conditional discriminator.

    The class label is embedded, broadcast into constant
    (num_classes, H, W) feature maps and concatenated with the image
    along the channel axis — which is why the first conv declares
    img_channels + num_classes input channels.  Sized for 56x56 inputs:
    56 -> 28 -> 14 -> 7, then a 7x7 conv reduces to a 1x1 score.

    Bug fixed: the original forward concatenated a 4-D image tensor with
    a 2-D embedding along the last dim (a runtime shape error) and then
    reshaped everything to 1x1, destroying all spatial information.
    """

    def __init__(self, num_classes, img_channels):
        super(Discriminator, self).__init__()
        self.num_classes = num_classes
        self.img_channels = img_channels
        self.label_emb = nn.Embedding(num_classes, num_classes)
        self.discriminator = nn.Sequential(
            nn.Conv2d(self.img_channels + self.num_classes, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 1, 7, 1, 0, bias=False),
            nn.Sigmoid()  # real/fake probability
        )

    def forward(self, img, labels):
        """Score (img, label) pairs; returns (B, 1) real/fake probabilities."""
        b, _, h, w = img.shape
        # (B, num_classes) -> (B, num_classes, H, W): tile each embedding
        # value over the full spatial extent so it can join the image channels.
        label_maps = (self.label_emb(labels)
                      .view(b, self.num_classes, 1, 1)
                      .expand(b, self.num_classes, h, w))
        disc_input = torch.cat((img, label_maps), dim=1)
        validity = self.discriminator(disc_input)
        return validity.view(-1, 1)
# 定义训练函数
def train(generator, discriminator, dataloader, num_epochs, z_dim, num_classes, device, lr):
    """Adversarially train the conditional GAN.

    Args:
        generator, discriminator: the two networks (moved to *device* here).
        dataloader: yields (images, integer labels) batches.
        num_epochs: number of passes over the dataset.
        z_dim: latent vector size fed to the generator.
        num_classes: number of label classes to sample for fakes.
        device: torch.device to train on.
        lr: Adam learning rate for both optimizers.

    Side effects: prints losses every 100 batches and writes a 5x5 sample
    grid to images/<epoch>.png every 5 epochs via save_image (defined
    later in this file).
    """
    import os  # local import keeps this snippet self-contained

    generator.to(device)
    discriminator.to(device)
    criterion = nn.BCELoss()
    optimizer_g = optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
    optimizer_d = optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
    os.makedirs("images", exist_ok=True)  # save_image fails on a missing dir
    for epoch in range(num_epochs):
        for i, (real_imgs, labels) in enumerate(dataloader):
            real_imgs = real_imgs.to(device)
            labels = labels.to(device)
            valid = torch.ones(real_imgs.size(0), 1).to(device)
            fake = torch.zeros(real_imgs.size(0), 1).to(device)
            # ---- discriminator step ----
            optimizer_d.zero_grad()
            z = torch.randn(real_imgs.size(0), z_dim).to(device)
            fake_labels = torch.randint(0, num_classes, (real_imgs.size(0),)).to(device)
            fake_imgs = generator(z, fake_labels)
            real_loss = criterion(discriminator(real_imgs, labels), valid)
            # detach(): the D step must not backprop into the generator
            fake_loss = criterion(discriminator(fake_imgs.detach(), fake_labels), fake)
            d_loss = (real_loss + fake_loss) / 2
            d_loss.backward()
            optimizer_d.step()
            # ---- generator step ----
            optimizer_g.zero_grad()
            z = torch.randn(real_imgs.size(0), z_dim).to(device)
            gen_labels = torch.randint(0, num_classes, (real_imgs.size(0),)).to(device)
            gen_imgs = generator(z, gen_labels)
            g_loss = criterion(discriminator(gen_imgs, gen_labels), valid)
            # Bug fixed: the original added
            #   CrossEntropyLoss()(generator.label_emb(gen_labels), gen_labels)
            # which scores the generator's own label embedding against the
            # labels -- it never looks at a generated image, so it merely
            # pushes the embedding toward one-hot vectors.  A genuine ACGAN
            # auxiliary loss requires a classification head on the
            # discriminator; the bogus term is dropped here.
            g_loss.backward()
            optimizer_g.step()
            if i % 100 == 0:
                print("[Epoch %d/%d] [Batch %d/%d] [D loss: %.4f] [G loss: %.4f]"
                      % (epoch, num_epochs, i, len(dataloader), d_loss.item(), g_loss.item()))
        if epoch % 5 == 0:
            # save_image is defined later in this file.
            save_image(gen_imgs.data[:25], "images/%d.png" % epoch, nrow=5, normalize=True)
# --- data pipeline ---
# NOTE(review): datasets.ImageFolder requires "faces" to contain one
# sub-directory per class (the sub-directory names become the labels);
# a flat folder of images will raise an error -- verify the dataset layout.
transform = transforms.Compose([
    transforms.Resize(56),      # bug fixed: networks above are sized for 56x56
    transforms.CenterCrop(56),  # (G outputs 56x56; D reduces 56 -> 1), so 64
    transforms.ToTensor(),      # here would crash the discriminator's view()
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),  # RGB -> [-1, 1], matching G's Tanh
])
dataset = datasets.ImageFolder("faces", transform=transform)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
# --- hyper-parameters ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
z_dim = 100                          # latent vector size
num_classes = len(dataset.classes)   # one class per sub-directory of "faces"
img_channels = 3
lr = 0.0002
num_epochs = 50
# --- build the models and train ---
generator = Generator(z_dim, num_classes, img_channels)
discriminator = Discriminator(num_classes, img_channels)
train(generator, discriminator, dataloader, num_epochs, z_dim, num_classes, device, lr)
# Bug fixed: the sampling snippet below loads "generator.pth", but the
# original never saved it -- persist the trained weights here.
torch.save(generator.state_dict(), "generator.pth")
```
训练过程需要一些时间,您可以在训练过程中保存一些生成的动漫头像,可以使用以下代码保存生成的图像。
```python
def save_image(imgs, path, nrow, normalize):
    """Tile a batch of images into a grid and write it to *path*.

    Args:
        imgs: (B, C, H, W) tensor of images (may live on GPU / carry grads).
        path: output image file path.
        nrow: images per row in the grid.
        normalize: if True, rescale pixel values into [0, 1] before saving
            (needed because the generator emits values in [-1, 1]).
    """
    batch = imgs.detach().cpu()  # drop the autograd graph and move off the GPU
    tiled = torchvision.utils.make_grid(batch, nrow=nrow, normalize=normalize)
    torchvision.utils.save_image(tiled, path)
```
最后,您可以使用以下代码生成一些新的动漫头像。
```python
# --- sample new avatars with the trained generator ---
generator = Generator(z_dim, num_classes, img_channels)
# map_location lets a CUDA-trained checkpoint load on a CPU-only machine.
generator.load_state_dict(torch.load("generator.pth", map_location=device))
generator.to(device)
generator.eval()  # BatchNorm must use running stats at inference time
with torch.no_grad():  # bug fixed: no autograd graph needed just to sample
    z = torch.randn(25, z_dim).to(device)
    labels = torch.randint(0, num_classes, (25,)).to(device)
    imgs = generator(z, labels)
save_image(imgs.data, "generated.png", nrow=5, normalize=True)
```
运行上面的代码后,您将在当前目录下看到一个名为“generated.png”的文件,其中包含25个生成的动漫头像。
阅读全文