写一个用GAN 生成一维特征的代码pytorch框架
时间: 2023-06-09 09:01:58 浏览: 107
基于pytorch的GAN.py代码
好的,以下是一个简单的用GAN生成一维特征的PyTorch代码:
```python
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
# 用于生成真实数据的函数,本例使用 y = x^2 + 0.1*noise
def get_real_samples(n_samples):
x = np.linspace(-1, 1, n_samples)
y = x ** 2 + 0.1 * np.random.randn(*x.shape)
return x, y
# 用于生成噪音数据的函数
def get_noise_samples(n_samples, noise_dim):
return np.random.randn(n_samples, noise_dim)
# 生成器及其对应的判别器定义
class Generator(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(Generator, self).__init__()
self.fc1 = nn.Linear(input_size, hidden_size)
self.fc2 = nn.Linear(hidden_size, hidden_size)
self.fc3 = nn.Linear(hidden_size, output_size)
def forward(self, x):
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
x = self.fc3(x)
return x
class Discriminator(nn.Module):
def __init__(self, input_size, hidden_size, output_size):
super(Discriminator, self).__init__()
self.fc1 = nn.Linear(input_size, hidden_size)
self.fc2 = nn.Linear(hidden_size, hidden_size)
self.fc3 = nn.Linear(hidden_size, output_size)
self.sigmoid = nn.Sigmoid()
def forward(self, x):
x = torch.relu(self.fc1(x))
x = torch.relu(self.fc2(x))
x = self.sigmoid(self.fc3(x))
return x
# GAN模型训练
def train_gan(generator, discriminator, gan, n_epochs, n_batch, n_noise):
# 损失函数及优化器
criterion = nn.BCELoss()
optimizer_g = torch.optim.Adam(generator.parameters(), lr=0.001)
optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=0.001)
# 训练过程中保存损失值
gen_loss_hist = []
dis_loss_hist = []
# 训练数据准备
x_real, _ = get_real_samples(n_batch)
x_real = torch.from_numpy(x_real).float()
y_real = torch.ones(n_batch, 1)
x_noise = torch.randn(n_batch, n_noise)
# GAN模型训练
for epoch in range(n_epochs):
# 更新判别器
y_fake = discriminator(generator(x_noise))
y_real_pred = discriminator(x_real.unsqueeze(1))
dis_loss_real = criterion(y_real_pred, y_real)
dis_loss_fake = criterion(y_fake, torch.zeros(n_batch, 1))
dis_loss = dis_loss_real + dis_loss_fake
dis_loss.backward()
optimizer_d.step()
discriminator.zero_grad()
# 更新生成器
y_fake = discriminator(generator(x_noise))
gen_loss = criterion(y_fake, y_real)
gen_loss.backward()
optimizer_g.step()
generator.zero_grad()
# 保存损失值
gen_loss_hist.append(gen_loss.item())
dis_loss_hist.append(dis_loss.item())
# 绘制损失值曲线
plt.plot(gen_loss_hist, label='Generator Loss')
plt.plot(dis_loss_hist, label='Discriminator Loss')
plt.legend()
# 模型训练
n_samples = 1000
n_noise = 10
n_hidden = 32
n_epochs = 2000
n_batch = 64
generator = Generator(n_noise, n_hidden, 1)
discriminator = Discriminator(1, n_hidden, 1)
gan = nn.Sequential(generator, discriminator)
train_gan(generator, discriminator, gan, n_epochs, n_batch, n_noise)
# 生成样本并绘制结果
x = np.linspace(-1, 1, n_samples)
x = torch.from_numpy(x).float()
y = generator(torch.randn(n_samples, n_noise)).squeeze().detach().numpy()
plt.figure()
plt.scatter(x, y, s=1)
plt.show()
```
该代码通过定义生成器和判别器来实现GAN模型,生成器用于生成数据,判别器用于判断真实数据和生成数据。训练过程中先更新判别器,然后再更新生成器。本例使用了一个简单的二次函数作为真实数据的生成函数,使用Adam优化器和BCELoss损失函数进行训练。最终生成的数据可以在代码的最后绘制出来。
阅读全文