PyTorch GAN 模型代码
时间: 2023-09-22 20:06:44 浏览: 134
Pytorch实现数字对抗样本生成全套代码(GAN)
5星 · 资源好评率100%
以下是一个简单的基于 PyTorch 的 GAN（DCGAN）模型代码示例：
```python
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.utils import save_image
from tqdm import tqdm
# Generator model.
class Generator(nn.Module):
    """DCGAN-style generator that maps a latent tensor to an image.

    Args:
        nz: Number of channels of the latent input.
        ngf: Base feature-map width; the hidden layers use ngf*8, ngf*4,
            ngf*2 and ngf channels.
        nc: Number of channels of the generated image.

    For an input of shape (N, nz, 1, 1), the transposed convolutions grow the
    spatial size 1 -> 4 -> 8 -> 16 -> 32 -> 64. The result is (N, nc, 64, 64),
    squashed into [-1, 1] by the final Tanh.
    """

    def __init__(self, nz=100, ngf=64, nc=3):
        super().__init__()
        hidden_widths = (ngf * 8, ngf * 4, ngf * 2, ngf)
        layers = []
        in_ch = nz
        for idx, out_ch in enumerate(hidden_widths):
            # The first layer projects 1x1 -> 4x4 (stride 1, no padding).
            # Every later layer doubles the spatial size (stride 2, padding 1).
            stride, padding = (1, 0) if idx == 0 else (2, 1)
            layers += [
                nn.ConvTranspose2d(in_ch, out_ch, 4, stride, padding, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(True),
            ]
            in_ch = out_ch
        # Output layer: to nc image channels, no normalisation, Tanh activation.
        layers += [nn.ConvTranspose2d(in_ch, nc, 4, 2, 1, bias=False), nn.Tanh()]
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        return self.main(x)
# Discriminator model.
class Discriminator(nn.Module):
    """DCGAN-style discriminator that scores images with a Sigmoid output.

    Args:
        ndf: Base feature-map width; the convolutions use ndf, ndf*2, ndf*4
            and ndf*8 channels.
        nc: Number of channels of the input image.

    Four stride-2 convolutions shrink a 64x64 input to 4x4, and a final 4x4
    convolution with no padding reduces that to a single value per sample.
    The forward pass therefore returns shape (N, 1) for (N, nc, 64, 64)
    input. Other input sizes yield a different number of rows.
    """

    def __init__(self, ndf=64, nc=3):
        super().__init__()
        # The first block has no BatchNorm; the three following blocks do.
        layers = [
            nn.Conv2d(nc, ndf, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        for mult in (1, 2, 4):
            c_in, c_out = ndf * mult, ndf * mult * 2
            layers.extend((
                nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ))
        layers.extend((nn.Conv2d(ndf * 8, 1, 4, 1, 0, bias=False), nn.Sigmoid()))
        self.main = nn.Sequential(*layers)

    def forward(self, x):
        scores = self.main(x)
        # Flatten (N, 1, 1, 1) to (N, 1).
        return scores.view(-1, 1)
# Training loop.
def _resolve_setting(value, name):
    """Return *value*, or the module-level global *name* when *value* is None.

    Raises:
        NameError: if *value* is None and no such global is defined.
    """
    if value is not None:
        return value
    try:
        return globals()[name]
    except KeyError:
        raise NameError(
            f"train() needs {name!r}: pass it as a keyword argument "
            "or define it at module level"
        ) from None


def train(dataloader, generator, discriminator, optimizer_g, optimizer_d,
          criterion, device, *, num_epochs=None, nz=None, fixed_noise=None,
          out_dir="gan_images", show_progress=True):
    """Train a GAN by alternating discriminator and generator updates.

    Args:
        dataloader: Iterable of batches whose element 0 is the image tensor
            (e.g. the ``(images, labels)`` pairs of an ImageFolder DataLoader).
        generator: Module mapping noise of shape (B, nz, 1, 1) to images the
            discriminator accepts.
        discriminator: Module mapping images to per-sample probabilities.
        optimizer_g: Optimizer over the generator's parameters.
        optimizer_d: Optimizer over the discriminator's parameters.
        criterion: Loss called as ``criterion(prediction, target)``, e.g.
            ``nn.BCELoss()``. Targets are built with the prediction's shape
            and dtype.
        device: Device that batches and noise are moved to.
        num_epochs: Number of passes over ``dataloader``. Defaults to the
            module-level global ``num_epochs``.
        nz: Latent size. Defaults to the module-level global ``nz``.
        fixed_noise: Noise used for the per-epoch sample grid. Defaults to the
            module-level global ``fixed_noise``; only needed when ``out_dir``
            is not None.
        out_dir: Directory for ``epoch_<n>.png`` sample grids (created if
            missing). ``None`` disables saving.
        show_progress: Wrap each epoch's batches in a tqdm progress bar.

    Raises:
        NameError: if a required setting is neither passed nor defined at
            module level.
        ValueError: if ``dataloader`` yields no batches in an epoch.
    """
    num_epochs = _resolve_setting(num_epochs, "num_epochs")
    nz = _resolve_setting(nz, "nz")
    if out_dir is not None:
        fixed_noise = _resolve_setting(fixed_noise, "fixed_noise")
        # save_image() does not create missing directories.
        os.makedirs(out_dir, exist_ok=True)

    for epoch in range(num_epochs):
        batches = tqdm(dataloader) if show_progress else dataloader
        loss_d = loss_g = None
        for data in batches:
            real_images = data[0].to(device)
            batch_size = real_images.size(0)

            # Discriminator step: push D(real) -> 1 and D(G(z)) -> 0.
            optimizer_d.zero_grad()
            output_real = discriminator(real_images)
            # BCELoss needs float targets shaped like the prediction.
            # torch.full((B,), 1) would be an int64 tensor of shape (B,).
            loss_d_real = criterion(output_real, torch.ones_like(output_real))
            loss_d_real.backward()

            noise = torch.randn(batch_size, nz, 1, 1, device=device)
            fake_images = generator(noise)
            # detach(): the discriminator step must not backprop into G.
            output_fake = discriminator(fake_images.detach())
            loss_d_fake = criterion(output_fake, torch.zeros_like(output_fake))
            loss_d_fake.backward()

            # Total discriminator loss (used for reporting only).
            loss_d = loss_d_real + loss_d_fake
            optimizer_d.step()

            # Generator step: make the (just updated) D label G(z) as real.
            optimizer_g.zero_grad()
            output_fake = discriminator(fake_images)
            loss_g = criterion(output_fake, torch.ones_like(output_fake))
            loss_g.backward()
            optimizer_g.step()

        if loss_d is None:
            raise ValueError("dataloader yielded no batches; nothing to train on")

        # End of epoch: report the last batch's losses and save samples.
        print(f"Epoch {epoch+1}/{num_epochs}, Loss_D: {loss_d.item()}, Loss_G: {loss_g.item()}")
        if out_dir is not None:
            # Sampling needs no autograd graph.
            with torch.no_grad():
                samples = generator(fixed_noise)
            save_image(
                samples.cpu(),
                os.path.join(out_dir, f"epoch_{epoch+1}.png"),
                normalize=True,
            )
# Hyperparameters. train() is called below without num_epochs / nz /
# fixed_noise, so it picks them up from these module-level globals.
num_epochs = 200
batch_size = 64
lr = 0.0002
beta1 = 0.5  # Adam beta1 (first-moment decay)
nz = 100  # size of the generator's latent input
ngf = 64  # generator base feature-map width
ndf = 64  # discriminator base feature-map width
nc = 3  # image channels

if __name__ == "__main__":
    # DataLoader workers (num_workers > 0) re-import this module under the
    # "spawn" start method (the default on Windows and macOS). Without this
    # guard the workers would re-execute data loading and training, which
    # fails while they are still bootstrapping.

    # Load the dataset. ImageFolder expects one sub-directory per class under
    # data/. Images are resized/cropped to 64x64, and ToTensor + Normalize
    # map them to [-1, 1].
    transform = transforms.Compose([
        transforms.Resize(64),
        transforms.CenterCrop(64),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    dataset = datasets.ImageFolder("data/", transform=transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

    # Build the generator and discriminator, on the GPU when available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    generator = Generator(nz=nz, ngf=ngf, nc=nc).to(device)
    discriminator = Discriminator(ndf=ndf, nc=nc).to(device)

    # Optimizers and loss function.
    optimizer_g = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
    optimizer_d = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))
    criterion = nn.BCELoss()

    # Per-epoch sample grids are written here; save_image does not create
    # missing directories.
    os.makedirs("gan_images", exist_ok=True)
    # Fixed generator input noise, so sample grids are comparable across epochs.
    fixed_noise = torch.randn(64, nz, 1, 1, device=device)

    # Train the model.
    train(dataloader, generator, discriminator, optimizer_g, optimizer_d, criterion, device)
```
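以上代码演示了如何使用 PyTorch 实现一个简单的 DCGAN。在训练过程中，我们使用交替优化的方式训练生成器和判别器：生成器的目标是欺骗判别器，而判别器的目标是尽可能地区分真实图像和生成图像。运行前需要按照 `torchvision.datasets.ImageFolder` 的要求，将训练图像放在 `data/` 下的子目录中；每个 epoch 生成的样例图像保存在 `gan_images/` 目录下。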
阅读全文