生成对抗网络（GAN）的 PyTorch 代码
时间: 2023-05-25 17:01:37 浏览: 132
A:
以下是在 MNIST 数据集上训练全连接生成对抗网络（GAN）的 PyTorch 代码示例：
```python
import torch
import torch.nn as nn
import numpy as np
# Generator network
class Generator(nn.Module):
    """Fully connected generator: latent vector -> flattened sample.

    Three ``nn.Linear`` layers with a ReLU between consecutive layers; the
    output layer has no activation.

    NOTE(review): because there is no output activation the generated values
    are unbounded, while the real MNIST batches are produced by
    ``transforms.ToTensor()`` (documented range [0, 1]). A final
    ``nn.Sigmoid()`` may suit the data better — confirm before changing.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        widths = [input_size, hidden_size, hidden_size, output_size]
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        layers.pop()  # no activation after the output layer
        # Attribute name and layer positions determine state_dict keys
        # ("generator.0.weight", ...), so they are kept stable.
        self.generator = nn.Sequential(*layers)

    def forward(self, x):
        return self.generator(x)
# Discriminator network
class Discriminator(nn.Module):
    """Fully connected discriminator: flattened sample -> score in [0, 1].

    Three ``nn.Linear`` layers with ReLU in between, followed by a Sigmoid so
    the output can be passed directly to ``nn.BCELoss``.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        widths = (input_size, hidden_size, hidden_size, output_size)
        stages = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            stages.extend((nn.Linear(fan_in, fan_out), nn.ReLU()))
        stages[-1] = nn.Sigmoid()  # squash the final layer's output instead of ReLU
        self.discriminator = nn.Sequential(*stages)

    def forward(self, x):
        return self.discriminator(x)
# Hyperparameters
batch_size = 100        # samples per DataLoader batch
# NOTE(review): 2e-4 is the Adam learning rate commonly used for GANs;
# 0.002 may be a typo — confirm against training results.
learning_rate = 0.002   # Adam learning rate shared by both networks
latent_size = 100       # length of the generator's input noise vector
hidden_size = 256       # width of every hidden layer in both networks
image_size = 28 * 28    # one MNIST image flattened to a vector
num_epochs = 200        # passes over the training set
# Data loading
def load_mnist_data(root='./data'):
    """Build DataLoaders over the MNIST train and test splits.

    Both splits are downloaded into ``root`` if missing and converted with
    ``transforms.ToTensor()``. Batches use the module-level ``batch_size``.

    Args:
        root: Directory the dataset is stored in / downloaded to.

    Returns:
        ``(train_loader, test_loader)``. Only the training loader shuffles;
        the test loader iterates in a fixed order.
    """
    # torchvision is imported lazily so the rest of the module does not need it.
    from torchvision.datasets import MNIST
    from torchvision import transforms
    to_tensor = transforms.ToTensor()
    train_data = MNIST(root=root, train=True, transform=to_tensor, download=True)
    test_data = MNIST(root=root, train=False, transform=to_tensor, download=True)
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
    # Evaluation does not benefit from shuffling; a fixed order keeps it reproducible.
    test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False)
    return train_loader, test_loader
# Only the training loader is used; the test loader is discarded.
train_loader = load_mnist_data()[0]

# G maps latent_size noise to an image_size vector; D maps an image_size
# vector to a single score.
generator = Generator(latent_size, hidden_size, image_size)
discriminator = Discriminator(image_size, hidden_size, 1)

# One Adam optimizer per network with the same learning rate, and a binary
# cross-entropy loss shared by both players.
G_optimizer, D_optimizer = [
    torch.optim.Adam(net.parameters(), lr=learning_rate)
    for net in (generator, discriminator)
]
criterion = nn.BCELoss()
# Training loop
def train_GAN(generator, discriminator, G_optimizer, D_optimizer, criterion, num_epochs, device,
              data_loader=None, latent_dim=None, log_interval=50):
    """Train a GAN with alternating discriminator / generator updates.

    For every batch the discriminator is updated on the real samples (target
    1) and on generated samples (target 0). Then the generator is updated so
    that the discriminator scores freshly generated samples as real.

    Args:
        generator: Module mapping noise of shape ``(n, latent_dim)`` to
            samples with the same flattened width as the real data.
        discriminator: Module mapping flattened samples to scores that
            ``criterion`` compares with ``(n, 1)`` 0/1 targets.
        G_optimizer: Optimizer over the generator's parameters.
        D_optimizer: Optimizer over the discriminator's parameters.
        criterion: Loss called as ``criterion(scores, targets)``, e.g.
            ``nn.BCELoss()``.
        num_epochs: Number of passes over the data loader.
        device: Device the models and all created tensors are placed on.
        data_loader: Iterable of ``(samples, labels)`` batches (labels are
            ignored). Defaults to the module-level ``train_loader``.
        latent_dim: Length of each noise vector. Defaults to the
            module-level ``latent_size``.
        log_interval: Positive int; losses are printed every this many steps.
    """
    # Fall back to the module-level objects so existing calls keep working.
    loader = train_loader if data_loader is None else data_loader
    noise_dim = latent_size if latent_dim is None else latent_dim

    generator.to(device)
    discriminator.to(device)
    generator.train()
    discriminator.train()
    for epoch in range(num_epochs):
        for idx, (real_data, _) in enumerate(loader):
            # Flatten every sample to a vector, keeping the batch dimension.
            real_data = real_data.flatten(start_dim=1).to(device)
            # Size noise and targets from the actual batch: the final batch of
            # an epoch can be smaller than the loader's nominal batch size, and
            # a size mismatch makes the loss raise.
            n = real_data.size(0)
            real_labels = torch.ones(n, 1, device=device)
            fake_labels = torch.zeros(n, 1, device=device)

            # Discriminator step. detach() keeps D_loss.backward() from
            # back-propagating into the generator, whose gradients this step
            # does not use.
            noise = torch.randn(n, noise_dim, device=device)
            fake_data = generator(noise).detach()
            D_optimizer.zero_grad()
            D_real_loss = criterion(discriminator(real_data), real_labels)
            D_fake_loss = criterion(discriminator(fake_data), fake_labels)
            D_loss = D_real_loss + D_fake_loss
            D_loss.backward()
            D_optimizer.step()

            # Generator step: target "real" for fresh fakes. The gradients this
            # leaves on the discriminator are cleared by D_optimizer.zero_grad()
            # before the next discriminator update.
            noise = torch.randn(n, noise_dim, device=device)
            fake_data = generator(noise)
            G_optimizer.zero_grad()
            G_loss = criterion(discriminator(fake_data), real_labels)
            G_loss.backward()
            G_optimizer.step()

            if idx % log_interval == 0:
                print('Epoch [{}/{}], Step [{}/{}], D_loss: {:.4f}, G_loss: {:.4f}'
                      .format(epoch+1, num_epochs, idx+1, len(loader), D_loss.item(), G_loss.item()))
# Run training, on the GPU when CUDA is available and on the CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
train_GAN(generator, discriminator, G_optimizer, D_optimizer, criterion,
          num_epochs, device)
```
阅读全文