WGAN自动生成动漫头像代码用PyTorch 实现
时间: 2024-05-13 15:16:33 浏览: 124
WGAN(Wasserstein GAN)是GAN(Generative Adversarial Network)的一种改进模型,它通过使用Wasserstein距离替代JS散度来解决GAN中的训练不稳定问题,从而提高了生成器和判别器的训练效果。在这里,我将介绍如何使用PyTorch实现WGAN来生成动漫头像。
首先,我们需要导入必要的库:
```python
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.utils import save_image
import matplotlib.pyplot as plt
import numpy as np
```
接下来,我们定义一些超参数:
```python
batch_size = 64 # 批次大小
n_epochs = 200 # 训练轮数
z_dim = 100 # 噪声维度
lr = 0.00005 # 学习率
clip_value = 0.01 # 截断值
n_critic = 5 # 判别器训练次数
```
然后,我们定义生成器和判别器的网络结构:
```python
class Generator(nn.Module):
    """DCGAN-style generator: latent z of shape (B, z_dim) -> RGB image (B, 3, 64, 64).

    Output values lie in [-1, 1] because of the final Tanh, matching the
    Normalize(mean=0.5, std=0.5) preprocessing of the training images.

    Args:
        z_dim: dimensionality of the latent vector. Defaults to 100, which
            matches the module-level hyperparameter the original code read
            implicitly from the global scope.
    """

    def __init__(self, z_dim=100):
        super(Generator, self).__init__()
        # Project z to a 4x4x512 feature map.
        self.fc = nn.Sequential(
            nn.Linear(z_dim, 4 * 4 * 512),
            nn.BatchNorm1d(4 * 4 * 512),
            # Fix: the original applied no nonlinearity after the projection;
            # DCGAN applies ReLU here. ReLU has no parameters, so checkpoints
            # saved with the old layout still load.
            nn.ReLU(True),
        )
        # Four stride-2 transposed convolutions: 4x4 -> 8 -> 16 -> 32 -> 64.
        self.conv = nn.Sequential(
            nn.ConvTranspose2d(512, 256, 4, 2, 1),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 3, 4, 2, 1),
            nn.Tanh(),  # squash to [-1, 1]
        )

    def forward(self, z):
        """Map latent vectors (B, z_dim) to images (B, 3, 64, 64)."""
        x = self.fc(z)
        x = x.view(-1, 512, 4, 4)  # reshape flat projection into a feature map
        x = self.conv(x)
        return x
class Discriminator(nn.Module):
    """WGAN critic: RGB image (B, 3, 64, 64) -> unbounded scalar score (B, 1).

    No sigmoid at the output — the WGAN critic produces a raw score whose
    expectation difference estimates the Wasserstein distance.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        widths = (64, 128, 256, 512)
        # First stage: plain conv + LeakyReLU, no BatchNorm on the input image.
        stages = [nn.Conv2d(3, widths[0], 4, 2, 1), nn.LeakyReLU(0.2)]
        # Remaining stages: stride-2 conv, BatchNorm, LeakyReLU.
        # Spatial size halves each stage: 64 -> 32 -> 16 -> 8 -> 4.
        for c_in, c_out in zip(widths, widths[1:]):
            stages.extend([
                nn.Conv2d(c_in, c_out, 4, 2, 1),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2),
            ])
        self.conv = nn.Sequential(*stages)
        self.fc = nn.Linear(4 * 4 * 512, 1)

    def forward(self, x):
        """Score a batch of images; returns a (B, 1) tensor of critic values."""
        features = self.conv(x)
        flat = features.view(-1, 4 * 4 * 512)
        return self.fc(flat)
```
接下来,我们定义WGAN模型:
```python
class WGAN(object):
    """Wasserstein GAN trainer with weight clipping (Arjovsky et al., 2017).

    Trains a Generator/Discriminator pair with RMSprop, running `n_critic`
    critic updates per generator update and clamping critic weights to
    [-clip_value, clip_value] after each critic step. Reads the module-level
    hyperparameters (n_epochs, z_dim, lr, clip_value, n_critic).
    """

    def __init__(self):
        # Fix: the original called .cuda() unconditionally and crashed on
        # CPU-only machines; pick the device at runtime instead.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.generator = Generator().to(self.device)
        self.discriminator = Discriminator().to(self.device)
        # WGAN uses RMSprop (no momentum), as recommended in the paper.
        self.optimizer_g = torch.optim.RMSprop(self.generator.parameters(), lr=lr)
        self.optimizer_d = torch.optim.RMSprop(self.discriminator.parameters(), lr=lr)
        # Note: the original also created an unused nn.MSELoss(); WGAN's loss
        # is the mean-score difference below, so the dead attribute is removed.

    def train(self, data_loader):
        """Run the full training loop over `data_loader` for n_epochs epochs.

        Saves a 64-sample image grid and the generator weights
        ('generator.pth') at the end of every epoch.
        """
        total_step = len(data_loader)
        for epoch in range(n_epochs):
            for i, (images, _) in enumerate(data_loader):
                images = images.to(self.device)
                # Fix: size z by the actual batch — the final batch of an
                # epoch may be smaller than the configured batch_size.
                cur_batch = images.size(0)
                # --- critic updates: maximize E[D(real)] - E[D(fake)] ---
                for _ in range(n_critic):
                    z = torch.randn(cur_batch, z_dim, device=self.device)
                    fake_images = self.generator(z)
                    real_out = self.discriminator(images)
                    # detach: no generator gradients during critic updates
                    fake_out = self.discriminator(fake_images.detach())
                    loss_d = -torch.mean(real_out) + torch.mean(fake_out)
                    self.optimizer_d.zero_grad()
                    loss_d.backward()
                    self.optimizer_d.step()
                    # Weight clipping enforces the Lipschitz constraint the
                    # Wasserstein estimate requires.
                    for p in self.discriminator.parameters():
                        p.data.clamp_(-clip_value, clip_value)
                # --- generator update: maximize E[D(fake)] ---
                z = torch.randn(cur_batch, z_dim, device=self.device)
                fake_images = self.generator(z)
                fake_out = self.discriminator(fake_images)
                loss_g = -torch.mean(fake_out)
                self.optimizer_g.zero_grad()
                loss_g.backward()
                self.optimizer_g.step()
                if (i+1) % 10 == 0:
                    print ('Epoch [{}/{}], Step [{}/{}], Loss_D: {:.4f}, Loss_G: {:.4f}'
                           .format(epoch+1, n_epochs, i+1, total_step,
                                   loss_d.item(), loss_g.item()))
            # Save a sample grid once per epoch, without tracking gradients.
            with torch.no_grad():
                sample_z = torch.randn(64, z_dim, device=self.device)
                fake_images = self.generator(sample_z).view(-1, 3, 64, 64)
                save_image(fake_images, 'generated_images-{}.png'.format(epoch+1))
            # Fix: the original never saved the generator, yet the sampling
            # script later loads './generator.pth'.
            torch.save(self.generator.state_dict(), 'generator.pth')
```
最后,我们加载动漫头像数据集,并训练WGAN模型:
```python
# 加载数据集
transform = transforms.Compose([
transforms.Resize(64),
transforms.CenterCrop(64),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
dataset = ImageFolder('./data', transform)
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# 训练WGAN模型
wgan = WGAN()
wgan.train(data_loader)
```
训练完成后,我们可以使用生成器生成一些动漫头像:
```python
# Load the trained generator and sample an 8x8 grid of avatars.
# Fix: the original built the model on CPU but sampled z with .cuda(),
# which raises a device-mismatch error; keep model and z on one device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
generator = Generator().to(device)
# map_location lets a GPU-trained checkpoint load on a CPU-only machine.
generator.load_state_dict(torch.load('./generator.pth', map_location=device))
# Fix: switch to eval mode so BatchNorm uses running statistics instead of
# the statistics of this sampling batch.
generator.eval()
with torch.no_grad():  # inference only — no autograd bookkeeping needed
    z = torch.randn(64, z_dim, device=device)
    fake_images = generator(z).view(-1, 3, 64, 64)
for i in range(64):
    plt.subplot(8, 8, i + 1)
    # De-normalize from [-1, 1] back to [0, 1] for display (HWC layout).
    plt.imshow((fake_images[i].cpu().numpy().transpose(1, 2, 0) + 1) / 2)
    plt.axis('off')
plt.show()
```
至此,我们就完成了使用PyTorch实现WGAN自动生成动漫头像的代码。
阅读全文