GAN PyTorch 代码
时间: 2023-10-02 15:09:44 浏览: 109
GAN（生成对抗网络）是一种深度学习模型，用于生成逼真的新样本数据：生成器从随机噪声中生成样本，判别器负责区分真实样本与生成样本，两者交替对抗训练，直到生成器能够“骗过”判别器。下面是一个使用 PyTorch 在 MNIST 手写数字数据集上训练简单全连接 GAN 的示例：
```python
import torch
import torch.nn as nn
import torchvision
# Generator network
class Generator(nn.Module):
    """Map latent noise vectors to samples whose values lie in [-1, 1].

    Architecture: Linear(input_dim -> hidden_dim) -> ReLU
    -> Linear(hidden_dim -> output_dim) -> Tanh.

    Args:
        input_dim: Size of the latent noise vector fed to ``forward``.
        output_dim: Size of each generated sample (784 in the MNIST example,
            i.e. a flattened 28x28 image).
        hidden_dim: Width of the hidden layer; defaults to 64 as before.
    """

    def __init__(self, input_dim, output_dim, hidden_dim=64):
        super().__init__()
        # Attribute names are kept (fc / relu / fc2 / tanh) so state_dict keys
        # stay compatible with checkpoints saved from the original class.
        self.fc = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)
        # Tanh bounds outputs to [-1, 1]; real training data must be scaled
        # to the same range or the discriminator can tell them apart trivially.
        self.tanh = nn.Tanh()

    def forward(self, x):
        """Return generated samples of shape ``(..., output_dim)`` for noise ``x`` of shape ``(..., input_dim)``."""
        hidden = self.relu(self.fc(x))
        return self.tanh(self.fc2(hidden))
# Discriminator network
class Discriminator(nn.Module):
    """Score samples with a probability in [0, 1] of being real.

    Architecture: Linear(input_dim -> hidden_dim) -> ReLU
    -> Linear(hidden_dim -> 1) -> Sigmoid.

    Args:
        input_dim: Size of each (flattened) input sample.
        hidden_dim: Width of the hidden layer; defaults to 64 as before.
    """

    def __init__(self, input_dim, hidden_dim=64):
        super().__init__()
        # Attribute names are kept so state_dict keys stay compatible.
        self.fc = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, 1)
        # Sigmoid produces probabilities, which is what nn.BCELoss expects.
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return real-probabilities of shape ``(..., 1)`` for ``x`` of shape ``(..., input_dim)``."""
        hidden = self.relu(self.fc(x))
        return self.sigmoid(self.fc2(hidden))
# Training procedure
def _infer_latent_dim(generator):
    """Return ``in_features`` of the first ``nn.Linear`` inside ``generator``."""
    for module in generator.modules():
        if isinstance(module, nn.Linear):
            return module.in_features
    raise ValueError(
        "could not infer the latent size from the generator; pass input_dim explicitly"
    )


def _real_batch(batch):
    """Extract the input tensor from a batch and flatten it to ``(batch_size, -1)``.

    torchvision datasets yield ``(inputs, labels)`` pairs; labels are unused here.
    """
    if isinstance(batch, (list, tuple)):
        batch = batch[0]
    return batch.reshape(batch.size(0), -1)


def train_gan(generator, discriminator, data_loader, num_epochs, lr,
              input_dim=None, log_every=10):
    """Train a GAN with binary cross-entropy and Adam.

    For every batch, the discriminator is first updated to label real data 1
    and generated data 0. The generator is then updated so the discriminator
    labels its samples 1.

    Args:
        generator: Module mapping noise of shape ``(batch, input_dim)`` to
            samples with the same flattened size as the real data.
        discriminator: Module mapping flattened samples to ``(batch, 1)``
            probabilities. It is paired with ``nn.BCELoss``, so its outputs
            must lie in [0, 1].
        data_loader: Iterable of batches. Each batch is either a tensor or an
            ``(inputs, labels)`` pair. Labels are ignored and inputs are
            flattened to ``(batch, -1)``.
        num_epochs: Number of passes over ``data_loader``.
        lr: Adam learning rate, shared by both networks.
        input_dim: Latent noise size. When None, it is inferred from the
            ``in_features`` of the generator's first ``nn.Linear``.
        log_every: Print the last batch's losses every ``log_every`` epochs.
            A falsy value disables logging.

    Raises:
        ValueError: If ``input_dim`` is None and cannot be inferred.
    """
    if input_dim is None:
        input_dim = _infer_latent_dim(generator)
    # Create labels/noise on the same device as the models (CPU if no params).
    first_param = next(generator.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    criterion = nn.BCELoss()
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr)
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr)
    generator.train()
    discriminator.train()

    for epoch in range(num_epochs):
        # Reset per epoch so an empty loader never logs stale/undefined losses.
        gen_loss = real_loss = fake_loss = None
        for batch in data_loader:
            real_data = _real_batch(batch).to(device)
            batch_size = real_data.size(0)
            real_labels = torch.ones(batch_size, 1, device=device)
            fake_labels = torch.zeros(batch_size, 1, device=device)

            # Discriminator step: push D(real) -> 1 and D(fake) -> 0.
            optimizer_D.zero_grad()
            real_loss = criterion(discriminator(real_data), real_labels)
            real_loss.backward()
            z = torch.randn(batch_size, input_dim, device=device)
            fake_data = generator(z)
            # detach(): this loss must not propagate gradients into the generator.
            fake_loss = criterion(discriminator(fake_data.detach()), fake_labels)
            fake_loss.backward()
            optimizer_D.step()

            # Generator step: re-score the fakes with the updated discriminator
            # and push D(G(z)) -> 1.
            optimizer_G.zero_grad()
            gen_loss = criterion(discriminator(fake_data), real_labels)
            gen_loss.backward()
            optimizer_G.step()

        if gen_loss is not None and log_every and (epoch+1) % log_every == 0:
            print(f"Epoch [{epoch+1}/{num_epochs}], G_Loss: {gen_loss.item():.4f}, D_Loss: {real_loss.item()+fake_loss.item():.4f}")
# Example data preparation: MNIST digits flattened to 784-dim vectors.
input_dim = 100    # latent noise size
output_dim = 784   # 28 * 28 pixels per MNIST image
batch_size = 100


def _images_only(batch):
    """Collate MNIST ``(image, label)`` samples into a ``(batch, 784)`` image tensor, dropping labels."""
    return torch.stack([image.reshape(-1) for image, _ in batch])


# ToTensor() yields pixels in [0, 1]; Normalize(0.5, 0.5) rescales them to
# [-1, 1] so real data matches the generator's Tanh output range.
transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.5,), (0.5,)),
])
data_loader = torch.utils.data.DataLoader(
    dataset=torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True),
    batch_size=batch_size, shuffle=True, collate_fn=_images_only)
# Create the generator and discriminator instances
generator = Generator(input_dim, output_dim)
discriminator = Discriminator(output_dim)
# Train the GAN
train_gan(generator, discriminator, data_loader, num_epochs=100, lr=0.0002)
```
阅读全文