ACGAN自动生成动漫头像PyTorch代码和 数据集
时间: 2024-06-09 07:04:32 浏览: 194
以下是ACGAN风格的自动生成动漫头像的PyTorch代码和数据集(注意:示例实现未使用ACGAN的辅助分类损失,实际上是一个简化的条件GAN):
## 数据集
我们将使用动漫头像数据集,该数据集包含10,000个大小为64x64的图像。您可以从以下链接下载数据集:
https://drive.google.com/file/d/1GhK8g-hPZ7z4mC1J1l8iYJ4Qqy1aY79f/view
将下载的文件解压缩到名为“anime”的文件夹中。
## PyTorch代码
```python
import glob
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image
from PIL import Image
# Hyperparameters
batch_size = 128  # images per training batch
lr = 0.0002  # Adam learning rate shared by G and D
latent_dim = 100  # size of the noise vector fed to the generator
num_classes = 10  # number of (arbitrary) condition classes
num_epochs = 200  # total passes over the dataset
# Use GPU when available, otherwise fall back to CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Preprocessing: resize to 64x64, convert to tensor, scale pixels to [-1, 1]
# (matches the Tanh output range of the generator)
transform = transforms.Compose([
transforms.Resize((64, 64)),
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
# 动漫头像数据集
class AnimeDataset(Dataset):
    """Anime avatar dataset: yields (optionally transformed) images from a flat folder.

    Args:
        root_dir: folder containing the image files.
        transform: optional torchvision transform applied to each PIL image.
        extensions: file extensions to collect. Defaults also pick up
            jpg/jpeg, generalizing the original png-only glob.
    """

    def __init__(self, root_dir, transform=None, extensions=('png', 'jpg', 'jpeg')):
        self.root_dir = root_dir
        self.transform = transform
        # Sort for a deterministic ordering across runs and filesystems
        # (glob order is OS-dependent).
        self.images = sorted(
            path
            for ext in extensions
            for path in glob.glob(root_dir + '/*.' + ext)
        )

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        img_path = self.images[idx]
        # convert('RGB') guards against grayscale/palette/RGBA files that
        # would otherwise break the 3-channel pipeline downstream.
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image
# 生成器
class Generator(nn.Module):
    """Conditional generator: maps (noise, class label) to a 3x64x64 image."""

    def __init__(self):
        super(Generator, self).__init__()
        # Learned embedding turns an integer label into a num_classes-dim vector.
        self.label_emb = nn.Embedding(num_classes, num_classes)

        def fc_block(in_feat, out_feat):
            # Linear -> BatchNorm -> LeakyReLU, same stack as each original stage.
            return [
                nn.Linear(in_feat, out_feat),
                nn.BatchNorm1d(out_feat, 0.8),
                nn.LeakyReLU(0.2, inplace=True),
            ]

        layers = []
        layers += fc_block(latent_dim + num_classes, 128)
        layers += fc_block(128, 256)
        layers += fc_block(256, 512)
        layers += fc_block(512, 1024)
        layers += [nn.Linear(1024, 64 * 64 * 3), nn.Tanh()]
        self.model = nn.Sequential(*layers)

    def forward(self, noise, labels):
        """Concatenate the label embedding with the noise and decode to images."""
        conditioned = torch.cat((self.label_emb(labels), noise), -1)
        flat = self.model(conditioned)
        return flat.view(flat.size(0), 3, 64, 64)
# 判别器
class Discriminator(nn.Module):
    """Conditional discriminator: scores an (image, label) pair as real/fake."""

    def __init__(self):
        super(Discriminator, self).__init__()
        self.label_emb = nn.Embedding(num_classes, num_classes)
        # Input layer takes the flattened image concatenated with the label embedding.
        layers = [
            nn.Linear(num_classes + 64 * 64 * 3, 512),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # Two identical hidden stages with dropout, as in the original stack.
        for _ in range(2):
            layers += [
                nn.Linear(512, 512),
                nn.Dropout(0.4),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        # Single sigmoid output: probability that the input is real.
        layers += [nn.Linear(512, 1), nn.Sigmoid()]
        self.model = nn.Sequential(*layers)

    def forward(self, img, labels):
        """Flatten the image, append the label embedding, return validity score."""
        flattened = img.view(img.size(0), -1)
        joined = torch.cat((flattened, self.label_emb(labels)), -1)
        return self.model(joined)
# Loss functions
adversarial_loss = nn.BCELoss()  # binary real/fake loss
# NOTE(review): auxiliary_loss is defined but never used in the training loop
# below — without an auxiliary classification head and loss this script is a
# conditional GAN, not a full ACGAN.
auxiliary_loss = nn.CrossEntropyLoss()
# Instantiate generator and discriminator on the selected device
generator = Generator().to(device)
discriminator = Discriminator().to(device)
# Optimizers (Adam with default betas; GAN papers often use betas=(0.5, 0.999))
optimizer_G = optim.Adam(generator.parameters(), lr=lr)
optimizer_D = optim.Adam(discriminator.parameters(), lr=lr)
# Dataset and loader over the unpacked "anime" folder
anime_data = AnimeDataset('anime', transform=transform)
dataloader = DataLoader(anime_data, batch_size=batch_size, shuffle=True)
# 训练模型
# Train the conditional GAN: real images target 1, generated images target 0.
os.makedirs('images', exist_ok=True)  # save_image below fails if the folder is missing
for epoch in range(num_epochs):
    for i, imgs in enumerate(dataloader):
        batch = imgs.size(0)

        # Adversarial ground truths
        valid = torch.ones((batch, 1)).to(device)
        fake = torch.zeros((batch, 1)).to(device)

        # ---------------------
        #  Train discriminator
        # ---------------------
        optimizer_D.zero_grad()

        real_imgs = imgs.to(device)
        # NOTE(review): the dataset carries no class labels, so real images are
        # paired with *random* labels here — the conditioning signal is
        # meaningless; a real ACGAN/CGAN needs the images' true labels.
        real_labels = torch.randint(0, num_classes, (batch,)).to(device)
        d_loss_real = adversarial_loss(discriminator(real_imgs, real_labels), valid)

        # Loss on generated images; detach() stops gradients flowing into G here.
        noise = torch.randn((batch, latent_dim)).to(device)
        fake_labels = torch.randint(0, num_classes, (batch,)).to(device)
        fake_imgs = generator(noise, fake_labels).detach()
        d_loss_fake = adversarial_loss(discriminator(fake_imgs, fake_labels), fake)

        # Average of real and fake losses
        d_loss = 0.5 * (d_loss_real + d_loss_fake)
        d_loss.backward()
        optimizer_D.step()

        # -----------------
        #  Train generator
        # -----------------
        optimizer_G.zero_grad()

        noise = torch.randn((batch, latent_dim)).to(device)
        fake_labels = torch.randint(0, num_classes, (batch,)).to(device)
        fake_imgs = generator(noise, fake_labels)
        # G is rewarded when D classifies its output as real
        g_loss = adversarial_loss(discriminator(fake_imgs, fake_labels), valid)
        g_loss.backward()
        optimizer_G.step()

        if i % 50 == 0:
            print('[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]' % (
                epoch, num_epochs, i, len(dataloader), d_loss.item(), g_loss.item()))

    # Save a 5x5 grid of the most recent generated batch every 10 epochs
    if epoch % 10 == 0:
        save_image(fake_imgs.data[:25], 'images/%d.png' % epoch, nrow=5, normalize=True)
```
您可以将上述代码保存为“acgan.py”文件并在命令行中运行以下命令以训练模型:
```
python acgan.py
```
注意:训练可能需要一段时间,具体取决于您的计算机性能。您可以通过调整超参数来加速训练,例如减少批量大小或减少训练时期。同时,您还可以在训练过程中查看生成的图像,这些图像将保存在名为“images”的文件夹中。
阅读全文