gan生成一维数据代码
时间: 2023-09-03 22:11:02 浏览: 250
以下是使用GAN生成一维数据的Python代码示例:
```python
import numpy as np
import matplotlib.pyplot as plt
from keras.models import Sequential
from keras.layers import Dense, LeakyReLU
# 生成真实数据
def generate_real_data(n_samples):
x = np.random.randn(n_samples) * 5 + 10
y = np.ones(n_samples)
return x, y
# 生成隐变量
def generate_latent_variable(n_samples, latent_dim):
x = np.random.randn(n_samples, latent_dim)
return x
# 生成假数据
def generate_fake_data(generator_model, n_samples, latent_dim):
x = generate_latent_variable(n_samples, latent_dim)
y = np.zeros(n_samples)
generated_data = generator_model.predict(x)
return generated_data.flatten(), y
# 定义生成器模型
def define_generator(latent_dim):
model = Sequential()
model.add(Dense(10, input_dim=latent_dim))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(1, activation='linear'))
return model
# 定义判别器模型
def define_discriminator():
model = Sequential()
model.add(Dense(10, input_dim=1))
model.add(LeakyReLU(alpha=0.2))
model.add(Dense(1, activation='sigmoid'))
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])
return model
# 定义GAN模型
def define_gan(generator_model, discriminator_model):
discriminator_model.trainable = False
model = Sequential()
model.add(generator_model)
model.add(discriminator_model)
model.compile(loss='binary_crossentropy', optimizer='adam')
return model
# 训练GAN模型
def train_gan(generator_model, discriminator_model, gan_model, n_epochs, n_batch, latent_dim):
for epoch in range(n_epochs):
for batch in range(n_batch):
# 更新判别器
x_real, y_real = generate_real_data(n_batch)
discriminator_model.train_on_batch(x_real, y_real)
x_fake, y_fake = generate_fake_data(generator_model, n_batch, latent_dim)
discriminator_model.train_on_batch(x_fake, y_fake)
# 更新生成器
x_gan = generate_latent_variable(n_batch, latent_dim)
y_gan = np.ones(n_batch)
gan_model.train_on_batch(x_gan, y_gan)
# 展示生成的数据
x_fake, y_fake = generate_fake_data(generator_model, 100, latent_dim)
plt.hist(x_fake, bins=20)
plt.show()
# 主函数
if __name__ == '__main__':
latent_dim = 5
discriminator_model = define_discriminator()
generator_model = define_generator(latent_dim)
gan_model = define_gan(generator_model, discriminator_model)
train_gan(generator_model, discriminator_model, gan_model, 5000, 128, latent_dim)
```
这段代码实现了一个简单的一维GAN模型,可以生成符合正态分布的一维数据。在训练过程中,每一轮训练结束后,会生成100个假数据并展示其分布情况。可以通过修改模型的结构和参数来生成不同的一维数据。
阅读全文