WGAN自动生成动漫头像PyTorch 代码
时间: 2024-05-14 18:18:44 浏览: 179
以下是使用 PyTorch 实现 WGAN 生成动漫头像的代码示例:
```python
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from torchvision.datasets import ImageFolder
from torchvision.utils import save_image
# --- Hyperparameters ---
batch_size = 64
num_epochs = 200
z_dim = 100  # dimensionality of the generator's input noise vector
lr = 0.0002
beta1 = 0.5  # Adam beta1; 0.5 is the common GAN setting
beta2 = 0.999
critic_iter = 5  # critic (discriminator) updates per generator update, as in WGAN
clamp_value = 0.01  # weight-clipping bound enforcing the critic's Lipschitz constraint
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# --- Data preprocessing: resize/center-crop to 64x64, then scale to [-1, 1]
# (Normalize with mean=std=0.5 maps [0,1] tensors to [-1,1], matching the
# generator's Tanh output range) ---
transform = transforms.Compose([
transforms.Resize(64),
transforms.CenterCrop(64),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
# --- Dataset: ImageFolder expects class subdirectories under ./data ---
dataset = ImageFolder(root='./data', transform=transform)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)
# 定义生成器
class Generator(nn.Module):
    """Map a z_dim-dimensional noise vector to a 3x64x64 image in [-1, 1]."""

    def __init__(self, z_dim):
        super(Generator, self).__init__()
        # Fully-connected stem: z -> 1024 -> flattened 8x8x256 feature map.
        self.fc = nn.Sequential(
            nn.Linear(z_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, 8 * 8 * 256),
            nn.BatchNorm1d(8 * 8 * 256),
            nn.ReLU(),
        )
        # Transposed-conv upsampling: 8x8 -> 16x16 -> 32x32 -> 64x64,
        # ending in Tanh so outputs land in [-1, 1].
        self.conv = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        )

    def forward(self, z):
        features = self.fc(z)
        grid = features.view(-1, 256, 8, 8)
        return self.conv(grid)
# 定义判别器
class Discriminator(nn.Module):
    """WGAN critic: map a 3x64x64 image to a single unbounded score.

    Fix: the original applied ``nn.Sigmoid()`` to the output, but the
    Wasserstein losses used in the training loop
    (``-(mean(real) - mean(fake))`` and ``-mean(fake)``) require raw,
    unbounded critic scores — squashing them into (0, 1) breaks the
    WGAN objective. The sigmoid is removed.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        # Strided convolutions downsample 64x64 -> 32x32 -> 16x16 -> 8x8.
        self.conv = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True)
        )
        # Linear head producing the raw critic score (no activation).
        self.fc = nn.Linear(256 * 8 * 8, 1)

    def forward(self, x):
        x = self.conv(x)
        x = x.view(-1, 256 * 8 * 8)
        return self.fc(x)
# Instantiate the generator and critic on the selected device.
G = Generator(z_dim).to(device)
D = Discriminator().to(device)

# Optimizers. NOTE(review): the original WGAN paper recommends RMSprop with
# weight clipping; Adam with beta1=0.5 is kept here as in the original code,
# but can destabilize clipped-WGAN training.
G_optimizer = optim.Adam(G.parameters(), lr=lr, betas=(beta1, beta2))
D_optimizer = optim.Adam(D.parameters(), lr=lr, betas=(beta1, beta2))

# Ensure the sample-output directory exists before the first save_image call
# (the original crashed with FileNotFoundError if ./images was missing).
os.makedirs('images', exist_ok=True)

total_step = len(dataloader)
writer = SummaryWriter()
for epoch in range(num_epochs):
    for i, (images, _) in enumerate(dataloader):
        images = images.to(device)
        # The last batch of an epoch can be smaller than batch_size; size the
        # noise to the real batch so critic scores are comparable.
        real_batch = images.size(0)

        # --- Train the critic for critic_iter steps per generator step ---
        for _ in range(critic_iter):
            z = torch.randn(real_batch, z_dim, device=device)
            fake_images = G(z)
            real_scores = D(images)
            fake_scores = D(fake_images.detach())
            # WGAN critic loss: maximize E[D(real)] - E[D(fake)].
            D_loss = -(torch.mean(real_scores) - torch.mean(fake_scores))
            D_optimizer.zero_grad()
            D_loss.backward()
            D_optimizer.step()
            # Weight clipping enforces the Lipschitz constraint (original WGAN).
            for p in D.parameters():
                p.data.clamp_(-clamp_value, clamp_value)

        # --- Train the generator: maximize E[D(G(z))] ---
        z = torch.randn(real_batch, z_dim, device=device)
        fake_images = G(z)
        fake_scores = D(fake_images)
        G_loss = -torch.mean(fake_scores)
        G_optimizer.zero_grad()
        G_loss.backward()
        G_optimizer.step()

        # Log losses to TensorBoard and the console.
        writer.add_scalar('D loss', D_loss.item(), epoch * total_step + i)
        writer.add_scalar('G loss', G_loss.item(), epoch * total_step + i)
        print('Epoch [{}/{}], Step [{}/{}], D Loss: {:.4f}, G Loss: {:.4f}'
              .format(epoch + 1, num_epochs, i + 1, total_step, D_loss.item(), G_loss.item()))

    # Save an 8x8 grid of generated samples at the end of each epoch.
    with torch.no_grad():
        z = torch.randn(64, z_dim, device=device)
        fake_images = G(z)
    save_image(fake_images, 'images/{}.png'.format(epoch + 1), nrow=8, normalize=True)
writer.close()
```
In this example we train the generator and critic with the original WGAN algorithm (Arjovsky et al.), which enforces the critic's Lipschitz constraint by clamping its weights to a small range after each update (the `clamp_value` parameter). Note that this is weight clipping, not WGAN-GP: the gradient-penalty variant replaces clipping with a penalty term on the critic's gradient norm and generally trains more stably, but it is not what this code implements. The example uses PyTorch's optimizers and data loader, and logs training losses (and saves generated image grids) via TensorBoard's SummaryWriter.
阅读全文