pytorch 生成对抗网络
时间: 2023-09-10 12:11:04 浏览: 162
PyTorch是一个非常流行的深度学习框架,可以用来实现各种类型的神经网络,包括生成对抗网络(GANs)。GAN是一种由两个深度神经网络组成的模型,一个生成器和一个判别器,用于生成逼真的图像或数据。
在PyTorch中,可以使用`torch.nn`模块中的类来定义生成器和判别器的模型结构,使用`torch.optim`模块中的类来定义优化器,并使用`torch.utils.data`模块中的类来加载数据集。然后,可以使用PyTorch的自动微分功能来计算损失并进行反向传播。
下面是一个简单的PyTorch GAN实现的示例代码:
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import transforms
# 定义生成器模型
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
self.fc1 = nn.Linear(100, 256)
self.fc2 = nn.Linear(256, 512)
self.fc3 = nn.Linear(512, 28*28)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
x = torch.tanh(self.fc3(x))
return x.view(-1, 1, 28, 28)
# 定义判别器模型
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
self.conv1 = nn.Conv2d(1, 16, kernel_size=5)
self.conv2 = nn.Conv2d(16, 32, kernel_size=5)
self.fc1 = nn.Linear(4*4*32, 1)
def forward(self, x):
x = torch.relu(self.conv1(x))
x = torch.max_pool2d(x, 2)
x = torch.relu(self.conv2(x))
x = torch.max_pool2d(x, 2)
x = x.view(-1, 4*4*32)
x = torch.sigmoid(self.fc1(x))
return x
# 定义损失函数和优化器
criterion = nn.BCELoss()
gen_optimizer = optim.Adam(generator.parameters(), lr=0.0002)
disc_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002)
# 加载数据集
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_data = MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = DataLoader(train_data, batch_size=128, shuffle=True)
# 训练GAN模型
for epoch in range(num_epochs):
for i, (real_images, _) in enumerate(train_loader):
# 训练判别器模型
disc_optimizer.zero_grad()
real_labels = torch.ones(real_images.size(0), 1)
fake_labels = torch.zeros(real_images.size(0), 1)
real_outputs = discriminator(real_images)
real_loss = criterion(real_outputs, real_labels)
noise = torch.randn(real_images.size(0), 100)
fake_images = generator(noise)
fake_outputs = discriminator(fake_images.detach())
fake_loss = criterion(fake_outputs, fake_labels)
disc_loss = real_loss + fake_loss
disc_loss.backward()
disc_optimizer.step()
# 训练生成器模型
gen_optimizer.zero_grad()
noise = torch.randn(real_images.size(0), 100)
fake_images = generator(noise)
fake_outputs = discriminator(fake_images)
gen_loss = criterion(fake_outputs, real_labels)
gen_loss.backward()
gen_optimizer.step()
# 打印损失
if i % 100 == 0:
print('Epoch [{}/{}], Step [{}/{}], Discriminator Loss: {:.4f}, Generator Loss: {:.4f}'
.format(epoch, num_epochs, i, len(train_loader), disc_loss.item(), gen_loss.item()))
```
在这个例子中,我们使用MNIST数据集来训练一个简单的GAN模型,其中生成器模型接受一个随机噪声向量作为输入,并输出一个28x28像素的图像。判别器模型接受一个图像作为输入,并输出一个二进制值,表示该图像是真实的还是虚假的。训练过程中,我们交替训练生成器和判别器模型,并计算损失和优化模型参数。
阅读全文