用 PyTorch 实现 GAN
时间: 2023-10-23 09:23:00 浏览: 94
Pytorch 实现GAN对抗学习.zip
5星 · 资源好评率100%
GAN(Generative Adversarial Network,生成对抗网络)是一种生成模型:它让生成器网络和判别器网络相互对抗地训练——生成器试图生成以假乱真的样本,判别器则试图区分真实样本与生成样本——从而使生成器学会生成与训练数据分布相似的新数据。下面是使用 PyTorch 实现 GAN 的基本步骤:
1. 导入必要的库并加载数据集
```python
import math

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
# Load the MNIST training set.
# ToTensor() scales pixels to [0, 1]; Normalize((0.5,), (0.5,)) then maps them to
# [-1, 1] so real images share the value range of the generator's Tanh output.
# Without it the discriminator can separate real from fake images by range alone.
train_data = datasets.MNIST(
    root='data',
    train=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]),
    download=True,
)
```
2. 定义生成器和判别器模型
```python
# Generator model: maps latent noise vectors to images.
class Generator(nn.Module):
    """Fully connected generator that turns noise vectors into images.

    Args:
        latent_dim: Size of each input noise vector.
        img_shape: Shape of one output image, e.g. ``(1, 28, 28)``. The last
            linear layer emits ``prod(img_shape)`` values per sample, which
            ``forward`` reshapes to ``(batch, *img_shape)``.

    Output values lie in ``[-1, 1]`` because the final activation is ``Tanh``.
    """

    def __init__(self, latent_dim, img_shape):
        super().__init__()
        self.latent_dim = latent_dim
        self.img_shape = img_shape
        # BatchNorm1d's second positional parameter is ``eps``, not ``momentum``;
        # a value such as 0.8 there is an unusually large epsilon, so the
        # defaults (eps=1e-5, momentum=0.1) are used instead.
        self.model = nn.Sequential(
            nn.Linear(self.latent_dim, 128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(128, 256),
            nn.BatchNorm1d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, math.prod(self.img_shape)),
            nn.Tanh(),
        )
        self.model.apply(self.init_weights)

    def init_weights(self, m):
        """Initialise ``nn.Linear`` layers: Xavier-uniform weights, bias 0.01.

        Other module types are left untouched. Intended for ``Module.apply``.
        """
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.01)

    def forward(self, z):
        """Map noise ``z`` of shape ``(batch, latent_dim)`` to images ``(batch, *img_shape)``."""
        img = self.model(z)
        return img.view(img.size(0), *self.img_shape)
# Discriminator model: scores each image with a probability-like value.
class Discriminator(nn.Module):
    """Fully connected discriminator producing one score per image.

    Args:
        img_shape: Shape of one input image, e.g. ``(1, 28, 28)``. Images are
            flattened to ``prod(img_shape)`` features before the first layer.

    ``forward`` returns a tensor of shape ``(batch, 1)`` with values in
    ``[0, 1]`` (sigmoid output), which is the input range ``nn.BCELoss`` expects.
    """

    def __init__(self, img_shape):
        super().__init__()
        self.img_shape = img_shape
        self.model = nn.Sequential(
            nn.Linear(math.prod(self.img_shape), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )
        self.model.apply(self.init_weights)

    def init_weights(self, m):
        """Initialise ``nn.Linear`` layers: Xavier-uniform weights, bias 0.01.

        Other module types are left untouched. Intended for ``Module.apply``.
        """
        if isinstance(m, nn.Linear):
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.01)

    def forward(self, img):
        """Flatten each image in the batch and return its ``(batch, 1)`` score."""
        img_flat = img.view(img.size(0), -1)
        return self.model(img_flat)
```
3. 实例化模型,并定义优化器和损失函数
```python
# Instantiate both networks, their optimisers and the loss.
latent_dim = 100                 # length of each noise vector fed to the generator
img_shape = (1, 28, 28)          # one MNIST image: (channels, height, width)

# Prefer the GPU when CUDA is available; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

generator = Generator(latent_dim=latent_dim, img_shape=img_shape).to(device)
discriminator = Discriminator(img_shape=img_shape).to(device)

# Both networks use Adam with lr=2e-4 and betas=(0.5, 0.999).
optimizer_G = optim.Adam(generator.parameters(), lr=2e-4, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=2e-4, betas=(0.5, 0.999))

# Binary cross-entropy on the discriminator's sigmoid output (1 = real, 0 = fake).
adversarial_loss = nn.BCELoss()
```
4. 训练模型
```python
# Train the GAN: every batch performs one discriminator update followed by one
# generator update.
num_epochs = 200
batch_size = 64
dataloader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
for epoch in range(num_epochs):
    for imgs, _ in dataloader:  # class labels are not needed
        real_imgs = imgs.to(device)
        cur_batch = real_imgs.size(0)  # the last batch may be smaller than batch_size
        # BCE targets, allocated directly on the device: 1 = real, 0 = fake.
        valid = torch.ones(cur_batch, 1, device=device)
        fake = torch.zeros(cur_batch, 1, device=device)

        # ---- Discriminator step ----
        optimizer_D.zero_grad()  # also clears grads left on D by the previous generator step
        z = torch.randn(cur_batch, latent_dim, device=device)
        gen_imgs = generator(z)
        loss_D_real = adversarial_loss(discriminator(real_imgs), valid)
        # detach() keeps the discriminator loss from backpropagating into the generator.
        loss_D_fake = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        loss_D = (loss_D_real + loss_D_fake) / 2
        loss_D.backward()
        optimizer_D.step()

        # ---- Generator step ----
        # gen_imgs can be reused: only its detached copy went through loss_D.backward(),
        # so its graph is intact, and the generator has not been updated yet. A second
        # forward pass would repeat work and update BatchNorm running stats twice.
        optimizer_G.zero_grad()
        # The generator is rewarded when the discriminator labels its samples as real.
        loss_G = adversarial_loss(discriminator(gen_imgs), valid)
        loss_G.backward()
        optimizer_G.step()
    print("[Epoch %d/%d] [Loss_D: %f] [Loss_G: %f]" % (epoch, num_epochs, loss_D.item(), loss_G.item()))
```
5. 生成新数据
```python
# Generate new data by sampling fresh noise and passing it through the trained generator.
num_samples = 100
# eval() makes BatchNorm1d use its running statistics (and stop updating them)
# rather than the statistics of this sampling batch; call generator.train()
# again before resuming training.
generator.eval()
with torch.no_grad():  # inference only: no autograd graph is needed
    z = torch.randn(num_samples, latent_dim, device=device)
    gen_data = generator(z)
# gen_data has shape (num_samples, *img_shape) with values in [-1, 1] (Tanh output);
# rescale with (gen_data + 1) / 2 before viewing the samples as [0, 1] images.
```
这就是使用PyTorch实现GAN的基本步骤。当然,你可以根据需要进行更改和调整。
阅读全文