生成对抗网络代码
时间: 2023-07-21 16:15:07 浏览: 88
生成对抗网络(GAN)是一种深度学习模型,由生成器和判别器两个部分组成,它们通过对抗的方式协同工作,达到生成高质量的数据样本的目的。以下是一个简单的GAN代码示例:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
# 定义生成器
class Generator(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(Generator, self).__init__()
self.fc1 = nn.Linear(input_size, hidden_size)
self.fc2 = nn.Linear(hidden_size, hidden_size)
self.fc3 = nn.Linear(hidden_size, output_size)
self.relu = nn.ReLU()
self.tanh = nn.Tanh()
def forward(self, x):
x = self.relu(self.fc1(x))
x = self.relu(self.fc2(x))
x = self.tanh(self.fc3(x))
return x
# 定义判别器
class Discriminator(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(Discriminator, self).__init__()
self.fc1 = nn.Linear(input_size, hidden_size)
self.fc2 = nn.Linear(hidden_size, hidden_size)
self.fc3 = nn.Linear(hidden_size, output_size)
self.relu = nn.ReLU()
self.sigmoid = nn.Sigmoid()
def forward(self, x):
x = self.relu(self.fc1(x))
x = self.relu(self.fc2(x))
x = self.sigmoid(self.fc3(x))
return x
# 定义超参数
input_size = 100
hidden_size = 128
output_size = 1
batch_size = 64
num_epochs = 50
lr = 0.0002
# 定义数据预处理
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
# 加载MNIST数据集
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# 初始化生成器和判别器
G = Generator(input_size, hidden_size, output_size)
D = Discriminator(input_size, hidden_size, output_size)
# 定义损失函数和优化器
criterion = nn.BCELoss()
G_optimizer = optim.Adam(G.parameters(), lr=lr)
D_optimizer = optim.Adam(D.parameters(), lr=lr)
# 训练模型
for epoch in range(num_epochs):
for i, (real_images, _) in enumerate(train_loader):
# 训练判别器
D.zero_grad()
real_labels = torch.ones(batch_size, 1)
fake_labels = torch.zeros(batch_size, 1)
real_outputs = D(real_images.view(batch_size, -1))
real_loss = criterion(real_outputs, real_labels)
z = torch.randn(batch_size, input_size)
fake_images = G(z)
fake_outputs = D(fake_images)
fake_loss = criterion(fake_outputs, fake_labels)
D_loss = real_loss + fake_loss
D_loss.backward()
D_optimizer.step()
# 训练生成器
G.zero_grad()
z = torch.randn(batch_size, input_size)
fake_images = G(z)
fake_outputs = D(fake_images)
G_loss = criterion(fake_outputs, real_labels)
G_loss.backward()
G_optimizer.step()
# 打印损失值
if i % 100 == 0:
print('Epoch [{}/{}], Step [{}/{}], D Loss: {:.4f}, G Loss: {:.4f}'
.format(epoch, num_epochs, i, len(train_loader), D_loss.item(), G_loss.item()))
```
在这个例子中,我们使用了PyTorch框架来实现GAN模型。首先定义了生成器和判别器的神经网络结构,然后定义了超参数和数据预处理方式。接着,我们加载MNIST数据集并迭代训练模型,其中每个迭代步骤都包括训练判别器和生成器。最后,我们在每个epoch中打印损失值以跟踪模型的训练过程。
阅读全文