PyTorch GAN Model
A GAN (Generative Adversarial Network) is a generative deep learning model made up of two neural networks: a generator and a discriminator. The generator learns to produce realistic samples such as images, text, or audio, while the discriminator learns to tell the generator's outputs apart from real data. The two networks are trained against each other, each improving in response to the other, until the generated samples become hard to distinguish from real ones.
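Formally, the two networks play the minimax game introduced in the original GAN paper (Goodfellow et al., 2014):

```latex
\min_G \max_D V(D, G) = \mathbb{E}_{x \sim p_{\mathrm{data}}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]
```

In practice, as in the code below, the generator is usually trained with the non-saturating variant: instead of minimizing log(1 - D(G(z))), it maximizes log D(G(z)), implemented by computing the binary cross-entropy of the discriminator's output on fake images against "real" labels.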
Below is a simple GAN implemented in PyTorch:
```python
import os

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image


# Generator: maps a 100-dim noise vector to a flattened 28x28 image
class Generator(nn.Module):
    def __init__(self):
        super(Generator, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(100, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, 784),
            nn.Tanh()  # outputs in [-1, 1], matching the normalized data
        )

    def forward(self, x):
        return self.fc(x)


# Discriminator: maps a flattened image to a real/fake probability
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(784, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, 1),
            nn.Sigmoid()  # probability that the input is real
        )

    def forward(self, x):
        return self.fc(x)


# Load the MNIST dataset, scaled to [-1, 1] to match the generator's Tanh output
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,))
])
train_dataset = dset.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True)

# Initialize the generator and discriminator
G = Generator()
D = Discriminator()

# Optimizers and loss function
G_optimizer = optim.Adam(G.parameters(), lr=0.0002, betas=(0.5, 0.999))
D_optimizer = optim.Adam(D.parameters(), lr=0.0002, betas=(0.5, 0.999))
criterion = nn.BCELoss()

# Train the GAN
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
G.to(device)
D.to(device)
os.makedirs('./gan_images', exist_ok=True)  # make sure the output directory exists

for epoch in range(100):
    for i, (images, _) in enumerate(train_loader):
        batch_size = images.size(0)
        images = images.view(batch_size, -1).to(device)  # flatten to (N, 784)

        # Train the discriminator: real images -> 1, generated images -> 0
        real_labels = torch.ones(batch_size, 1).to(device)
        fake_labels = torch.zeros(batch_size, 1).to(device)
        noise = torch.randn(batch_size, 100).to(device)
        fake_images = G(noise)
        D_real_outputs = D(images)
        D_fake_outputs = D(fake_images.detach())  # detach so G receives no gradients here
        D_real_loss = criterion(D_real_outputs, real_labels)
        D_fake_loss = criterion(D_fake_outputs, fake_labels)
        D_loss = D_real_loss + D_fake_loss
        D_optimizer.zero_grad()
        D_loss.backward()
        D_optimizer.step()

        # Train the generator: try to make D classify fakes as real
        noise = torch.randn(batch_size, 100).to(device)
        fake_images = G(noise)
        D_fake_outputs = D(fake_images)
        G_loss = criterion(D_fake_outputs, real_labels)
        G_optimizer.zero_grad()
        G_loss.backward()
        G_optimizer.step()

        # Log losses
        if i % 100 == 0:
            print(f'Epoch [{epoch+1}/100] Batch [{i+1}/{len(train_loader)}] '
                  f'D_loss: {D_loss.item():.4f}, G_loss: {G_loss.item():.4f}')

    # Save a batch of generated images after each epoch
    with torch.no_grad():
        noise = torch.randn(64, 100).to(device)
        fake_images = G(noise).view(64, 1, 28, 28)
        save_image((fake_images + 1) / 2, f'./gan_images/{epoch+1}.png')  # rescale [-1, 1] -> [0, 1]
```
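Once training finishes, only the generator is needed for sampling; the discriminator can be discarded. A minimal sketch of saving and reusing it (the filenames `generator.pth` and `samples.png` are illustrative choices, not part of the code above):

```python
# Save the trained generator's weights
torch.save(G.state_dict(), 'generator.pth')

# Later: reload it and generate a grid of new digits
G = Generator().to(device)
G.load_state_dict(torch.load('generator.pth', map_location=device))
G.eval()
with torch.no_grad():
    samples = G(torch.randn(16, 100, device=device)).view(16, 1, 28, 28)
save_image((samples + 1) / 2, 'samples.png')  # rescale Tanh output to [0, 1]
```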
In this example we use PyTorch's built-in MNIST dataset, with a generator made of four fully connected layers and a discriminator made of three. Both networks are trained with the Adam optimizer and a binary cross-entropy loss. During training we alternate between updating the discriminator and the generator, and at the end of each epoch we save a batch of generated images.
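One design note: the final `nn.Sigmoid()` followed by `nn.BCELoss` can be numerically unstable when the discriminator's output saturates. A common alternative, not used in the original code, is to drop the Sigmoid from the discriminator and switch to `nn.BCEWithLogitsLoss`, which fuses the sigmoid and the loss in a more stable form:

```python
# Alternative setup: remove nn.Sigmoid() from Discriminator.fc, then use
criterion = nn.BCEWithLogitsLoss()  # applies the sigmoid internally
```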