gan pytorch代码讲解
时间: 2023-09-23 12:13:30 浏览: 105
GAN(Generative Adversarial Networks)是一种生成模型,由 Goodfellow 等人在 2014 年提出。GAN 由一个生成器和一个判别器组成,生成器用于生成假样本,判别器用于区分真实样本和假样本。生成器和判别器通过对抗学习的方式进行训练,最终生成器可以生成与真实数据相似的样本。
下面是一个使用 PyTorch 实现的 GAN 模型的代码讲解。
首先,我们需要导入必要的库。
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
```
接着,我们定义生成器和判别器的结构。这里我们使用的是简单的全连接网络。
```python
class Generator(nn.Module):
def __init__(self, z_dim, hidden_dim, img_dim):
super(Generator, self).__init__()
self.z_dim = z_dim
self.hidden_dim = hidden_dim
self.img_dim = img_dim
self.fc1 = nn.Linear(z_dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, hidden_dim)
self.fc3 = nn.Linear(hidden_dim, img_dim)
def forward(self, x):
x = nn.functional.relu(self.fc1(x))
x = nn.functional.relu(self.fc2(x))
x = torch.tanh(self.fc3(x))
return x
class Discriminator(nn.Module):
def __init__(self, img_dim, hidden_dim):
super(Discriminator, self).__init__()
self.img_dim = img_dim
self.hidden_dim = hidden_dim
self.fc1 = nn.Linear(img_dim, hidden_dim)
self.fc2 = nn.Linear(hidden_dim, hidden_dim)
self.fc3 = nn.Linear(hidden_dim, 1)
def forward(self, x):
x = nn.functional.relu(self.fc1(x))
x = nn.functional.relu(self.fc2(x))
x = torch.sigmoid(self.fc3(x))
return x
```
接着,我们需要定义损失函数和优化器。
```python
z_dim = 100
hidden_dim = 128
img_dim = 784
generator = Generator(z_dim, hidden_dim, img_dim)
discriminator = Discriminator(img_dim, hidden_dim)
criterion = nn.BCELoss()
g_optimizer = optim.Adam(generator.parameters(), lr=0.0002)
d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002)
```
然后,我们需要训练模型。G 和 D 分别进行训练,每次迭代更新 G 和 D 的参数。在每次迭代中,我们先生成一些假样本,然后训练 D 区分真假样本,并更新 D 的参数。接着,我们生成一些假样本,让 D 去判断这些假样本是否为真,然后训练 G,使得生成的假样本越来越接近真实样本。
```python
num_epochs = 20
batch_size = 100
train_dataset = datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
for epoch in range(num_epochs):
for i, (real_images, _) in enumerate(train_loader):
batch_size = real_images.size(0)
real_labels = torch.ones(batch_size, 1)
fake_labels = torch.zeros(batch_size, 1)
z = torch.randn(batch_size, z_dim)
fake_images = generator(z)
# Train Discriminator
real_outputs = discriminator(real_images.view(batch_size, -1))
d_loss_real = criterion(real_outputs, real_labels)
fake_outputs = discriminator(fake_images.detach().view(batch_size, -1))
d_loss_fake = criterion(fake_outputs, fake_labels)
d_loss = d_loss_real + d_loss_fake
d_optimizer.zero_grad()
d_loss.backward()
d_optimizer.step()
# Train Generator
z = torch.randn(batch_size, z_dim)
fake_images = generator(z)
fake_outputs = discriminator(fake_images.view(batch_size, -1))
g_loss = criterion(fake_outputs, real_labels)
g_optimizer.zero_grad()
g_loss.backward()
g_optimizer.step()
if (i+1) % 100 == 0:
print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}'
.format(epoch+1, num_epochs, i+1, len(train_loader), d_loss.item(), g_loss.item()))
# Generate some fake images
with torch.no_grad():
z = torch.randn(16, z_dim)
fake_images = generator(z)
fake_images = fake_images.view(-1, 28, 28)
plt.imshow(fake_images[0].cpu(), cmap='gray')
plt.show()
```
最后,我们可以使用生成器生成一些假样本,并查看生成的假样本是否与真实样本相似。
```python
with torch.no_grad():
z = torch.randn(16, z_dim)
fake_images = generator(z)
fake_images = fake_images.view(-1, 28, 28)
plt.imshow(fake_images[0].cpu(), cmap='gray')
plt.show()
```
这就是使用 PyTorch 实现 GAN 的完整代码和讲解。
阅读全文