GAN 代码（Python）
时间: 2023-11-09 17:02:08 浏览: 77
GAN 神经网络的 Python 代码
以下是一个简单的 GAN 代码示例，使用 PyTorch 实现：
```python
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
# Generator network
class Generator(nn.Module):
    """Three-layer MLP that maps an input vector to a generated sample.

    Args:
        input_size: width of each input vector (the latent noise in this
            file's training loop).
        hidden_size: width of the two hidden layers.
        output_size: width of each generated sample.

    The last layer has no activation, so outputs are unbounded reals.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Attribute names/order define the state_dict keys; keep them stable.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``fc3(relu(fc2(relu(fc1(x)))))``."""
        for hidden_layer in (self.fc1, self.fc2):
            x = self.relu(hidden_layer(x))
        return self.fc3(x)
# Discriminator network
class Discriminator(nn.Module):
    """Three-layer MLP that scores samples, with a sigmoid after every layer.

    Args:
        input_size: width of each input sample.
        hidden_size: width of the two hidden layers.
        output_size: number of scores produced per sample.

    Because the final layer is also passed through a sigmoid, the outputs
    lie in [0, 1] and can be fed directly to ``nn.BCELoss``.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        # Attribute names/order define the state_dict keys; keep them stable.
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.fc3 = nn.Linear(hidden_size, output_size)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``sigmoid(fc3(sigmoid(fc2(sigmoid(fc1(x))))))``."""
        # NOTE(review): sigmoid on the hidden layers tends to saturate and slow
        # GAN training; LeakyReLU is the more common choice. Kept as-is here.
        for layer in (self.fc1, self.fc2, self.fc3):
            x = self.sigmoid(layer(x))
        return x
# Hyperparameters
input_size = 100    # width of the generator's latent noise vector
hidden_size = 128   # width of the hidden layers in both networks
output_size = 1     # width of one data sample (the toy data here is 1-D)
num_epochs = 2000
batch_size = 64
lr = 0.0002


def train_gan(
    generator,
    discriminator,
    *,
    input_size=100,
    output_size=1,
    num_epochs=2000,
    batch_size=64,
    lr=0.0002,
    samples_per_epoch=10000,
    real_mean=0.0,
    real_std=1.0,
    log_every=100,
):
    """Train ``generator`` to imitate samples drawn from N(real_mean, real_std**2).

    Each step does one discriminator update on a real batch and a fake batch.
    It then does one generator update, labelling fresh fakes as real (the
    non-saturating generator loss).

    Args:
        generator: module mapping ``(batch, input_size)`` noise to
            ``(batch, output_size)`` samples.
        discriminator: module mapping ``(batch, output_size)`` samples to
            probabilities in [0, 1]; it is scored with ``nn.BCELoss``.
        input_size: width of the latent noise vector.
        output_size: width of one real/fake sample.
        num_epochs: number of epochs to train.
        batch_size: samples per training step.
        lr: Adam learning rate for both networks.
        samples_per_epoch: one step is taken per ``batch_size`` chunk of this
            range, and every step uses a full batch.
        real_mean, real_std: parameters of the Gaussian the real data comes from.
        log_every: print the latest losses every ``log_every`` epochs
            (0 disables printing).

    Returns:
        ``(d_loss, g_loss)`` from the final step as Python floats, or
        ``(nan, nan)`` if no step was taken.
    """
    criterion = nn.BCELoss()
    g_optimizer = torch.optim.Adam(generator.parameters(), lr=lr)
    d_optimizer = torch.optim.Adam(discriminator.parameters(), lr=lr)
    d_loss = g_loss = None

    for epoch in range(num_epochs):
        for _ in range(0, samples_per_epoch, batch_size):
            # Real samples must have the discriminator's input width
            # (output_size), not the latent width (input_size).
            real_data = real_mean + real_std * torch.randn(batch_size, output_size)
            fake_data = generator(torch.randn(batch_size, input_size))

            # Discriminator step. detach() keeps this backward pass out of the
            # generator; those gradients would be discarded anyway.
            d_real = discriminator(real_data)
            d_fake = discriminator(fake_data.detach())
            d_loss = (criterion(d_real, torch.ones_like(d_real))
                      + criterion(d_fake, torch.zeros_like(d_fake)))
            d_optimizer.zero_grad()
            d_loss.backward()
            d_optimizer.step()

            # Generator step on a fresh batch: fakes are labelled "real" so the
            # generator is pushed towards fooling the updated discriminator.
            # (The gradients this leaves on the discriminator are cleared by
            # d_optimizer.zero_grad() before its next step.)
            d_fake = discriminator(generator(torch.randn(batch_size, input_size)))
            g_loss = criterion(d_fake, torch.ones_like(d_fake))
            g_optimizer.zero_grad()
            g_loss.backward()
            g_optimizer.step()

        if log_every and d_loss is not None and (epoch + 1) % log_every == 0:
            print(f"Epoch [{epoch + 1}/{num_epochs}], "
                  f"d_loss: {d_loss.item():.6f}, g_loss: {g_loss.item():.6f}")

    if d_loss is None:
        return float("nan"), float("nan")
    return d_loss.item(), g_loss.item()


if __name__ == "__main__":
    # G maps input_size-wide noise to output_size-wide samples; D maps a
    # sample to output_size probabilities (a single one with these settings).
    G = Generator(input_size, hidden_size, output_size)
    D = Discriminator(output_size, hidden_size, output_size)
    train_gan(G, D, input_size=input_size, output_size=output_size,
              num_epochs=num_epochs, batch_size=batch_size, lr=lr)

    # Draw samples from the trained generator and plot their histogram.
    with torch.no_grad():
        fake_data = G(torch.randn(1000, input_size)).numpy()
    plt.hist(fake_data.ravel(), bins=50)
    plt.show()
```
阅读全文