GAN 的 PyTorch 代码
时间: 2023-09-20 14:13:26 浏览: 95
以下是一个简单的 GAN 的 PyTorch 代码示例：
```python
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
# Generator network: maps a latent noise vector to a synthetic sample.
class Generator(nn.Module):
    """Three-layer MLP generator.

    Architecture: input_size -> 128 -> 256 -> output_size, with ReLU after the
    two hidden layers and no activation on the output layer.

    Args:
        input_size: Length of the noise vector (last dimension of the input).
        output_size: Length of each generated sample.
    """

    def __init__(self, input_size, output_size):
        super().__init__()
        # Creation order fixes how parameter initialisation draws from torch's
        # RNG; the attribute names become the state_dict keys.
        self.fc1 = nn.Linear(input_size, 128)
        self.fc2 = nn.Linear(128, 256)
        self.fc3 = nn.Linear(256, output_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return generated samples for the noise batch ``x``."""
        for hidden_layer in (self.fc1, self.fc2):
            x = self.relu(hidden_layer(x))
        return self.fc3(x)
# Discriminator network: scores how likely a sample is to be real.
class Discriminator(nn.Module):
    """Three-layer MLP discriminator returning a probability per sample.

    Architecture: input_size -> 256 -> 128 -> 1, with LeakyReLU after the two
    hidden layers and a sigmoid on the output, so the result lies in [0, 1]
    and can be scored directly with ``nn.BCELoss``.

    Args:
        input_size: Number of features per sample (last dimension of the input).
        negative_slope: Slope of the LeakyReLU for negative inputs on the
            hidden layers.
    """

    def __init__(self, input_size, negative_slope=0.2):
        super().__init__()
        self.fc1 = nn.Linear(input_size, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 1)
        # Hidden layers use LeakyReLU instead of sigmoid: a sigmoid's derivative
        # is at most 0.25 and vanishes when saturated, so stacking it shrinks the
        # gradients that both D and G learn from.
        self.leaky_relu = nn.LeakyReLU(negative_slope)
        # Sigmoid only on the output, to produce a probability for BCELoss.
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return the probability, shape ``(..., 1)``, that each sample is real."""
        x = self.leaky_relu(self.fc1(x))
        x = self.leaky_relu(self.fc2(x))
        return self.sigmoid(self.fc3(x))
# Training loop: alternate one discriminator update and one generator update per batch.
def train(num_epochs, generator, discriminator, gen_optimizer, disc_optimizer, data,
          noise_dim=100, batch_size=1, log_every=100):
    """Train a GAN by alternating discriminator and generator updates.

    Args:
        num_epochs: Number of passes over ``data``.
        generator: Module mapping noise of shape ``(batch, noise_dim)`` to fake
            samples with the same per-sample shape as the rows of ``data``.
        discriminator: Module mapping a batch of samples to probabilities in
            [0, 1]; its output is scored with ``nn.BCELoss``.
        gen_optimizer: Optimizer that updates the generator's parameters.
        disc_optimizer: Optimizer that updates the discriminator's parameters.
        data: Sliceable collection of real samples, one sample per row
            (e.g. a 2-D numpy array).
        noise_dim: Length of the latent noise vector sampled for the generator.
        batch_size: Number of real samples per update step.
        log_every: Print both losses every ``log_every`` batches.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1, got %r" % (batch_size,))
    criterion = nn.BCELoss()  # stateless; build once instead of per loss term
    num_batches = (len(data) + batch_size - 1) // batch_size  # ceil division
    for epoch in range(num_epochs):
        for i in range(num_batches):
            # Slicing (rather than data[i]) keeps a leading batch dimension, so
            # real and fake predictions both have shape (batch, 1).
            rows = np.asarray(data[i * batch_size:(i + 1) * batch_size])
            real_data = torch.as_tensor(rows, dtype=torch.float32)
            n = real_data.size(0)

            # Train the discriminator: push D(real) toward 1 and D(fake) toward 0.
            with torch.no_grad():
                # The generator is not updated in this step, so skip its graph.
                fake_data = generator(torch.randn(n, noise_dim))
            disc_optimizer.zero_grad()
            real_pred = discriminator(real_data)
            fake_pred = discriminator(fake_data)
            real_loss = criterion(real_pred, torch.ones_like(real_pred))
            fake_loss = criterion(fake_pred, torch.zeros_like(fake_pred))
            disc_loss = real_loss + fake_loss
            disc_loss.backward()
            disc_optimizer.step()

            # Train the generator: push D(G(z)) toward 1 on a fresh noise batch.
            gen_optimizer.zero_grad()
            fake_data = generator(torch.randn(n, noise_dim))
            fake_pred = discriminator(fake_data)
            gen_loss = criterion(fake_pred, torch.ones_like(fake_pred))
            # backward() also accumulates gradients on the discriminator's
            # parameters; they are not stepped here and are cleared by
            # disc_optimizer.zero_grad() before the next discriminator update
            # (assuming disc_optimizer owns those parameters).
            gen_loss.backward()
            gen_optimizer.step()

            # Report progress.
            if (i + 1) % log_every == 0:
                print("[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]"
                      % (epoch + 1, num_epochs, i + 1, num_batches,
                         disc_loss.item(), gen_loss.item()))
# Synthetic stand-in for a real dataset: 10000 samples, each a 50-dimensional
# vector drawn from a standard normal distribution (replace with real data).
NOISE_DIM = 100       # generator input size; must match the noise size train() samples
DATA_DIM = 50         # features per sample: generator output == discriminator input
NUM_SAMPLES = 10000
NUM_EPOCHS = 50
LEARNING_RATE = 0.0002

data = np.random.randn(NUM_SAMPLES, DATA_DIM)

# Build both networks, each with its own Adam optimizer.
generator = Generator(NOISE_DIM, DATA_DIM)
discriminator = Discriminator(DATA_DIM)
gen_optimizer = optim.Adam(generator.parameters(), lr=LEARNING_RATE)
disc_optimizer = optim.Adam(discriminator.parameters(), lr=LEARNING_RATE)

# NOTE(review): this runs the full training loop at module import time; wrap it
# in `if __name__ == "__main__":` if this module is ever imported elsewhere.
train(NUM_EPOCHS, generator, discriminator, gen_optimizer, disc_optimizer, data)
```
这是一个非常简单的 GAN，生成器和判别器都是由全连接层构成的网络。训练时两者交替进行：每一步先训练判别器，再训练生成器。示例中每个批次只包含一个样本（批大小为 1），每 100 个批次输出一次损失。另外，示例中的 `data` 只是随机生成的正态分布数据，实际使用时应替换为真实数据集。
阅读全文