基于PyTorch的ACGAN动漫头像自动生成代码
时间: 2024-01-16 21:03:13 浏览: 136
Anim_GAN:PyTorch-Gan自动生成的动漫头像
以下是一个基于PyTorch实现的ACGAN（辅助分类器生成对抗网络）自动生成动漫头像的代码示例：
首先，我们需要导入所需的PyTorch库和其他必要的库：
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import random
```
然后,我们需要定义一些超参数:
```python
# Training hyper-parameters.
batch_size, image_size = 64, 64  # images per optimisation step; square side length in pixels
num_epochs = 100                 # full passes over the dataset
num_classes = 10                 # number of label categories the models are built for
latent_dim = 100                 # length of the generator's noise vector
lr = 2e-4                        # Adam learning rate (same value as 0.0002)
beta1, beta2 = 0.5, 0.999        # Adam `betas` coefficients
```
接下来，我们需要定义数据加载器。注意：`ImageFolder`要求`./data`目录下按类别划分子文件夹存放图片，每个子文件夹对应一个类别标签：
```python
# Preprocessing: resize/crop to a square of `image_size`, convert to a
# tensor in [0, 1], then Normalize((x - 0.5) / 0.5) maps it to [-1, 1] so
# real images share the range of the generator's Tanh output.
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
])

# ImageFolder treats each sub-directory of ./data as one class and yields
# (image, class_index) pairs.
dataset = datasets.ImageFolder(root='./data', transform=transform)

# Class indices are fed to nn.Embedding / CrossEntropyLoss sized by
# num_classes; more folders than that would fail deep inside training.
if len(dataset.classes) > num_classes:
    raise ValueError(
        f"./data contains {len(dataset.classes)} class folders but "
        f"num_classes is {num_classes}; increase num_classes."
    )

# drop_last=True: a trailing batch of size 1 would make the generator's
# BatchNorm1d raise in training mode ("Expected more than 1 value per channel").
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True,
                        drop_last=True)
```
然后,我们需要定义生成器和判别器模型:
```python
class Generator(nn.Module):
    """Class-conditional generator of the ACGAN.

    Concatenates a learned label embedding with a noise vector and decodes
    it into an RGB image of shape ``(3, image_size, image_size)`` whose
    values lie in ``[-1, 1]`` (final ``Tanh``).

    Args:
        latent_dim: Length of the noise vector passed to ``forward``.
        num_classes: Number of distinct integer class labels.
        image_size: Output side length. Must be a positive multiple of 4,
            because the network starts at ``image_size // 4`` and upsamples
            twice by a factor of 2.

    Raises:
        ValueError: If ``image_size`` is not a positive multiple of 4.
    """

    def __init__(self, latent_dim, num_classes, image_size):
        super(Generator, self).__init__()
        if image_size <= 0 or image_size % 4 != 0:
            raise ValueError(
                f"image_size must be a positive multiple of 4, got {image_size}"
            )
        self.latent_dim = latent_dim
        self.num_classes = num_classes
        self.image_size = image_size
        init_size = image_size // 4
        # Embedding width is num_classes so that [embedding, noise] has
        # exactly latent_dim + num_classes features, which is what the first
        # Linear layer consumes.
        self.label_emb = nn.Embedding(num_classes, num_classes)
        self.model = nn.Sequential(
            nn.Linear(latent_dim + num_classes, 128 * init_size ** 2),
            nn.BatchNorm1d(128 * init_size ** 2),
            nn.LeakyReLU(0.2, inplace=True),
            # torch.nn has no `Reshape` layer; Unflatten turns the flat
            # feature vector into a (128, init_size, init_size) feature map.
            nn.Unflatten(1, (128, init_size, init_size)),
            # kernel 4, stride 2, padding 1 doubles the spatial size.
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2, inplace=True),
            nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False),
            nn.Tanh(),
        )

    def forward(self, noise, labels):
        """Generate images for the given noise and class labels.

        Args:
            noise: Float tensor of shape ``(batch, latent_dim)``.
            labels: Integer tensor of shape ``(batch,)`` with values in
                ``[0, num_classes)``.

        Returns:
            Tensor of shape ``(batch, 3, image_size, image_size)``.
        """
        gen_input = torch.cat((self.label_emb(labels), noise), -1)
        return self.model(gen_input)
class Discriminator(nn.Module):
    """ACGAN discriminator with a real/fake head and an auxiliary class head.

    The output of ``forward`` is one tensor of shape
    ``(batch, num_classes + 1)``:

    * column 0 is the probability that the image is real (sigmoid applied,
      suitable for ``nn.BCELoss``);
    * columns 1..num_classes are raw class logits (no activation, suitable
      for ``nn.CrossEntropyLoss``).

    Args:
        num_classes: Number of class labels predicted by the auxiliary head.
        image_size: Input side length. Must be a positive multiple of 16,
            because four stride-2 convolutions reduce it to
            ``image_size // 16`` before the final convolution collapses it
            to 1x1.

    Raises:
        ValueError: If ``image_size`` is not a positive multiple of 16.
    """

    def __init__(self, num_classes, image_size):
        super(Discriminator, self).__init__()
        if image_size < 16 or image_size % 16 != 0:
            raise ValueError(
                f"image_size must be a positive multiple of 16, got {image_size}"
            )
        self.num_classes = num_classes
        self.image_size = image_size
        # NOTE: not used by forward() (an ACGAN discriminator predicts the
        # label rather than being conditioned on it). Kept so existing
        # state_dict checkpoints still load.
        self.label_emb = nn.Embedding(num_classes, image_size ** 2)
        # After four stride-2 convs the feature map is final_kernel x final_kernel.
        final_kernel = image_size // 16
        self.model = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, 4, 2, 1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, 4, 2, 1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, 4, 2, 1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            # No activation here: the sigmoid is applied to the validity
            # column only, in forward().
            nn.Conv2d(512, num_classes + 1, final_kernel, 1, 0, bias=False),
        )

    def forward(self, img, labels=None):
        """Score images as real/fake and predict their class.

        Args:
            img: Float tensor of shape ``(batch, 3, image_size, image_size)``.
            labels: Unused; accepted for call-site compatibility.

        Returns:
            Tensor ``(batch, num_classes + 1)``: column 0 is the real
            probability, the remaining columns are class logits.
        """
        h = self.model(img).flatten(1)
        validity = torch.sigmoid(h[:, :1])
        class_logits = h[:, 1:]
        return torch.cat((validity, class_logits), dim=1)
```
接下来,我们需要定义损失函数和优化器:
```python
# Adversarial loss on the discriminator's real/fake probability.
criterion = nn.BCELoss()
# Auxiliary ACGAN loss on the discriminator's class logits.
dis_criterion = nn.CrossEntropyLoss()

# Use the GPU when available; fall back to CPU instead of crashing on .cuda().
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

gen = Generator(latent_dim, num_classes, image_size).to(device)
dis = Discriminator(num_classes, image_size).to(device)

opt_gen = optim.Adam(gen.parameters(), lr=lr, betas=(beta1, beta2))
opt_dis = optim.Adam(dis.parameters(), lr=lr, betas=(beta1, beta2))
```
然后,我们定义训练循环:
```python
import os

# ACGAN training loop. dis(...) returns (batch, num_classes + 1): column 0
# is the real/fake score used with BCELoss, and columns 1.. are the class
# predictions used with CrossEntropyLoss.
os.makedirs('./results', exist_ok=True)
# Follow whatever device the models were placed on.
device = next(dis.parameters()).device
fake_imgs = None

for epoch in range(num_epochs):
    for i, (imgs, labels) in enumerate(dataloader):
        # Local name, so the global `batch_size` hyper-parameter is not clobbered.
        cur_batch = imgs.size(0)
        real_imgs = imgs.to(device)
        labels = labels.to(device)
        valid = torch.ones(cur_batch, 1, device=device)
        fake = torch.zeros(cur_batch, 1, device=device)

        # ---- Train Discriminator ----
        opt_dis.zero_grad()
        noise = torch.randn(cur_batch, latent_dim, device=device)
        fake_labels = torch.randint(0, num_classes, (cur_batch,), device=device)
        fake_imgs = gen(noise, fake_labels)
        real_out = dis(real_imgs, labels)
        # detach(): the discriminator update must not backpropagate into G.
        fake_out = dis(fake_imgs.detach(), fake_labels)
        real_loss = (criterion(real_out[:, :1], valid)
                     + dis_criterion(real_out[:, 1:], labels))
        fake_loss = (criterion(fake_out[:, :1], fake)
                     + dis_criterion(fake_out[:, 1:], fake_labels))
        dis_loss = real_loss + fake_loss
        dis_loss.backward()
        opt_dis.step()

        # ---- Train Generator ----
        # G wants D to call its images real AND to classify them as the
        # requested label.
        opt_gen.zero_grad()
        noise = torch.randn(cur_batch, latent_dim, device=device)
        fake_labels = torch.randint(0, num_classes, (cur_batch,), device=device)
        fake_imgs = gen(noise, fake_labels)
        gen_out = dis(fake_imgs, fake_labels)
        gen_loss = (criterion(gen_out[:, :1], valid)
                    + dis_criterion(gen_out[:, 1:], fake_labels))
        gen_loss.backward()
        opt_gen.step()

        if i % 100 == 0:
            print('[%d/%d][%d/%d] D_loss: %.4f G_loss: %.4f'
                  % (epoch+1, num_epochs, i, len(dataloader),
                     dis_loss.item(), gen_loss.item()))

    # Save the last generated batch every 10 epochs (skipped if the
    # dataloader produced no batches).
    if epoch % 10 == 0 and fake_imgs is not None:
        vutils.save_image(fake_imgs.detach(),
                          './results/fake_samples_epoch_%03d.png' % epoch,
                          normalize=True)
```
最后,我们可以使用生成器模型生成一些新的动漫头像:
```python
import os

# Sample a grid of new avatars from the trained generator.
num_samples = 64
device = next(gen.parameters()).device
os.makedirs('./results', exist_ok=True)
# eval(): BatchNorm uses its running statistics instead of batch statistics.
gen.eval()
# no_grad(): inference only, so no autograd graph is needed.
with torch.no_grad():
    noise = torch.randn(num_samples, latent_dim, device=device)
    labels = torch.randint(0, num_classes, (num_samples,), device=device)
    fake_imgs = gen(noise, labels)
vutils.save_image(fake_imgs,
                  './results/fake_samples_final.png',
                  normalize=True)
```
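这样就完成了ACGAN自动生成动漫头像的代码实现。需要注意的是，ACGAN的判别器除了输出真/假概率外，还必须输出类别预测；训练时应同时优化对抗损失（BCELoss）和辅助分类损失（CrossEntropyLoss），否则判别器无法为生成器提供类别监督信号。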
阅读全文