pytorch 搭建gan
时间: 2023-05-16 16:06:57 浏览: 248
生成对抗网络(GAN)是一种深度学习模型,它由两个神经网络组成:生成器和判别器。PyTorch 是一个流行的深度学习框架,可以用于搭建 GAN 模型。以下是一个简单的 PyTorch GAN 模型的代码示例:
```python
import torch
import torch.nn as nn
# 定义生成器模型
class Generator(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(Generator, self).__init__()
self.fc1 = nn.Linear(input_size, hidden_size)
self.fc2 = nn.Linear(hidden_size, hidden_size)
self.fc3 = nn.Linear(hidden_size, output_size)
self.relu = nn.ReLU()
def forward(self, x):
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
x = self.relu(x)
x = self.fc3(x)
return x
# 定义判别器模型
class Discriminator(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(Discriminator, self).__init__()
self.fc1 = nn.Linear(input_size, hidden_size)
self.fc2 = nn.Linear(hidden_size, hidden_size)
self.fc3 = nn.Linear(hidden_size, output_size)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
x = self.fc1(x)
x = self.sigmoid(x)
x = self.fc2(x)
x = self.sigmoid(x)
x = self.fc3(x)
x = self.sigmoid(x)
return x
# 定义训练函数
def train_gan(generator, discriminator, data, num_epochs, batch_size, learning_rate):
criterion = nn.BCELoss()
optimizer_g = torch.optim.Adam(generator.parameters(), lr=learning_rate)
optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=learning_rate)
for epoch in range(num_epochs):
for i in range(0, len(data), batch_size):
# 训练判别器
discriminator.zero_grad()
real_data = data[i:i+batch_size]
real_labels = torch.ones(len(real_data), 1)
fake_data = generator(torch.randn(len(real_data), 100))
fake_labels = torch.zeros(len(real_data), 1)
real_outputs = discriminator(real_data)
fake_outputs = discriminator(fake_data)
loss_d = criterion(real_outputs, real_labels) + criterion(fake_outputs, fake_labels)
loss_d.backward()
optimizer_d.step()
# 训练生成器
generator.zero_grad()
fake_data = generator(torch.randn(len(real_data), 100))
fake_labels = torch.ones(len(real_data), 1)
fake_outputs = discriminator(fake_data)
loss_g = criterion(fake_outputs, fake_labels)
loss_g.backward()
optimizer_g.step()
print('Epoch [{}/{}], Loss D: {:.4f}, Loss G: {:.4f}'.format(epoch+1, num_epochs, loss_d.item(), loss_g.item()))
# 使用示例
data = torch.randn(1000, 10)
generator = Generator(100, 128, 10)
discriminator = Discriminator(10, 128, 1)
train_gan(generator, discriminator, data, num_epochs=100, batch_size=64, learning_rate=0.001)
```
这个示例中,我们定义了一个简单的生成器和判别器模型,并使用 PyTorch 的 BCELoss 作为损失函数。在训练过程中,我们首先训练判别器,然后训练生成器,交替进行。最终,我们可以得到一个可以生成与输入数据相似的新数据的生成器模型。
阅读全文