PyTorch Code for Automatically Generating Anime Avatars with WGAN
Below is a PyTorch code example that uses a WGAN to generate anime avatars:
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import os
from PIL import Image
# Hyperparameters
batch_size = 64
lr = 0.00005
z_dim = 100
n_critic = 5            # critic updates per generator update
lambda_gp = 10          # gradient penalty coefficient
epochs = 200
save_dir = 'generated_images'
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Generator: maps a z_dim x 1 x 1 noise vector to a 3 x 64 x 64 image
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.model = nn.Sequential(
            nn.ConvTranspose2d(in_channels=z_dim, out_channels=512, kernel_size=4, stride=1, padding=0),  # 4x4
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(in_channels=512, out_channels=256, kernel_size=4, stride=2, padding=1),    # 8x8
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(in_channels=256, out_channels=128, kernel_size=4, stride=2, padding=1),    # 16x16
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(in_channels=128, out_channels=64, kernel_size=4, stride=2, padding=1),     # 32x32
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(in_channels=64, out_channels=3, kernel_size=4, stride=2, padding=1),       # 64x64
            nn.Tanh()
        )

    def forward(self, x):
        return self.model(x)
# Critic (discriminator): maps a 3 x 64 x 64 image to a single unbounded score
# Note: the WGAN-GP paper recommends layer/instance norm in the critic instead of
# batch norm; batch norm is kept here to stay close to the original architecture
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.model = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=4, stride=2, padding=1),    # 32x32
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=4, stride=2, padding=1),  # 16x16
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=4, stride=2, padding=1), # 8x8
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=4, stride=2, padding=1), # 4x4
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(in_channels=512, out_channels=1, kernel_size=4, stride=1, padding=0)    # 1x1
        )

    def forward(self, x):
        return self.model(x)
# Load the dataset (ImageFolder expects images under data/<class_name>/*.jpg)
transform = transforms.Compose([
    transforms.Resize(64),
    transforms.CenterCrop(64),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
dataset = datasets.ImageFolder(root='data', transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)
# Initialize the generator and the critic and move them to the device
generator = Generator().to(device)
discriminator = Discriminator().to(device)

# Optimizers (the original WGAN paper recommends RMSprop)
optimizer_G = optim.RMSprop(generator.parameters(), lr=lr)
optimizer_D = optim.RMSprop(discriminator.parameters(), lr=lr)
# Training loop
for epoch in range(epochs):
    for i, data in enumerate(dataloader):
        real_images, _ = data
        real_images = real_images.to(device)
        batch_size = real_images.size(0)

        # ---------------------
        #  Train the critic
        # ---------------------
        for j in range(n_critic):
            discriminator.zero_grad()

            # Sample noise and generate fake images (detached so no gradients flow to G)
            z = torch.randn(batch_size, z_dim, 1, 1, device=device)
            fake_images = generator(z).detach()

            # Wasserstein critic loss: maximize D(real) - D(fake), i.e. minimize D(fake) - D(real)
            d_loss = torch.mean(discriminator(fake_images)) - torch.mean(discriminator(real_images))

            # Gradient penalty on random interpolations between real and fake samples
            alpha = torch.rand(batch_size, 1, 1, 1, device=device)
            interpolates = (alpha * real_images + (1 - alpha) * fake_images).requires_grad_(True)
            d_interpolates = discriminator(interpolates)
            gradients = torch.autograd.grad(outputs=d_interpolates, inputs=interpolates,
                                            grad_outputs=torch.ones_like(d_interpolates),
                                            create_graph=True, retain_graph=True, only_inputs=True)[0]
            gradients_norm = torch.sqrt(torch.sum(gradients ** 2, dim=(1, 2, 3)) + 1e-12)
            gradient_penalty = lambda_gp * ((gradients_norm - 1) ** 2).mean()

            # Total critic loss (the gradient penalty replaces weight clipping, so no clamp here)
            d_loss = d_loss + gradient_penalty
            d_loss.backward()
            optimizer_D.step()

        # ---------------------
        #  Train the generator
        # ---------------------
        generator.zero_grad()
        z = torch.randn(batch_size, z_dim, 1, 1, device=device)
        g_loss = -torch.mean(discriminator(generator(z)))
        g_loss.backward()
        optimizer_G.step()

        # Log losses
        if i % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}'
                  .format(epoch, epochs, i, len(dataloader), d_loss.item(), g_loss.item()))

    # Save a batch of generated images at the end of each epoch
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    with torch.no_grad():
        z = torch.randn(batch_size, z_dim, 1, 1, device=device)
        fake_images = generator(z).cpu()
    for k in range(fake_images.size(0)):
        fake_image = fake_images[k].numpy()
        fake_image = np.transpose(fake_image, (1, 2, 0))   # CHW -> HWC
        fake_image = (fake_image + 1) / 2                  # [-1, 1] -> [0, 1]
        fake_image = (fake_image * 255).astype(np.uint8)
        Image.fromarray(fake_image).save(os.path.join(save_dir, 'fake_image_{}_{}.png'.format(epoch, k)))
```
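For reference, the objective implemented by the training loop above is the standard WGAN-GP formulation: the critic minimizes the negated Wasserstein estimate plus a gradient penalty computed on random interpolations between real and generated samples, while the generator minimizes the negative critic score on its own samples (λ corresponds to `lambda_gp` in the code):
```latex
L_D = \mathbb{E}_{\tilde{x} \sim \mathbb{P}_g}\big[D(\tilde{x})\big]
    - \mathbb{E}_{x \sim \mathbb{P}_r}\big[D(x)\big]
    + \lambda \, \mathbb{E}_{\hat{x} \sim \mathbb{P}_{\hat{x}}}\big[\big(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\big)^2\big],
\qquad
L_G = -\,\mathbb{E}_{\tilde{x} \sim \mathbb{P}_g}\big[D(\tilde{x})\big]
```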
In this example, we use a Wasserstein GAN (WGAN) to generate anime avatars. WGAN is a variant of the GAN that replaces the original objective with the Wasserstein distance (also known as the Earth Mover's distance) to address the training instability of standard GANs. Here the Lipschitz constraint on the critic is enforced with a gradient penalty (the WGAN-GP formulation) rather than weight clipping. We define a simple DCGAN-style generator and critic, train them with RMSprop optimizers, update the critic n_critic times for every generator update, and save a batch of generated images at the end of each epoch.
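After training, the generator can be used on its own to sample new avatars. Below is a minimal sketch of how this might look, assuming the generator weights were saved at the end of training with `torch.save(generator.state_dict(), 'generator.pth')`; the checkpoint filename and the use of `torchvision.utils.save_image` are assumptions for illustration, not part of the original script:
```python
import torch
from torchvision.utils import save_image

# Assumes the Generator class, z_dim and device are defined as in the training
# script above, and that the trained weights were saved to 'generator.pth'
# (a hypothetical filename chosen for this example).
generator = Generator().to(device)
generator.load_state_dict(torch.load('generator.pth', map_location=device))
generator.eval()

with torch.no_grad():
    z = torch.randn(64, z_dim, 1, 1, device=device)   # 64 random latent vectors
    samples = generator(z)                             # Tanh output, values in [-1, 1]

# normalize=True min-max rescales the batch to [0, 1]; nrow=8 lays out an 8x8 grid
save_image(samples, 'samples_grid.png', nrow=8, normalize=True)
```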