写一段GAN代码，用于解决结构化数据的样本不均衡问题
时间: 2023-04-02 17:01:39 浏览: 134
以下是一段GAN代码，用于解决结构化数据的样本不均衡问题：
```python
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, LeakyReLU, BatchNormalization
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
import numpy as np
# Generator: maps a latent noise vector to one synthetic feature row.
def build_generator(latent_dim, output_dim):
    """Build the (uncompiled) generator network.

    The network is three Dense -> LeakyReLU(alpha=0.2) ->
    BatchNormalization(momentum=0.8) stages of widths 128, 256 and 512,
    followed by a Dense output layer with a sigmoid activation. Every
    generated feature therefore lies in (0, 1), so the real training data
    should be put on the same scale (e.g. min-max scaled).

    Args:
        latent_dim: length of the input noise vector.
        output_dim: number of features in each generated sample.

    Returns:
        A Keras ``Model`` mapping ``(batch, latent_dim)`` noise to
        ``(batch, output_dim)`` samples.
    """
    noise_in = Input(shape=(latent_dim,))
    hidden = noise_in
    for width in (128, 256, 512):
        hidden = Dense(width)(hidden)
        hidden = LeakyReLU(alpha=.2)(hidden)
        hidden = BatchNormalization(momentum=.8)(hidden)
    samples_out = Dense(output_dim, activation='sigmoid')(hidden)
    return Model(noise_in, samples_out)
# Discriminator: scores how likely a feature row is to be real.
def build_discriminator(input_dim):
    """Build the (uncompiled) discriminator network.

    The network is two Dense -> LeakyReLU(alpha=0.2) stages of widths 512
    and 256, followed by a single sigmoid unit.

    Args:
        input_dim: number of features in each input sample.

    Returns:
        A Keras ``Model`` mapping ``(batch, input_dim)`` samples to a
        ``(batch, 1)`` score in (0, 1).
    """
    features_in = Input(shape=(input_dim,))
    hidden = features_in
    for width in (512, 256):
        hidden = Dense(width)(hidden)
        hidden = LeakyReLU(alpha=.2)(hidden)
    realness = Dense(1, activation='sigmoid')(hidden)
    return Model(features_in, realness)
# Combined model: generator followed by a frozen discriminator.
def build_gan(generator, discriminator, latent_dim=None):
    """Stack the generator and discriminator into the generator-training model.

    ``discriminator.trainable`` is set to ``False`` so that training the
    combined model updates only the generator's weights. In tf.keras the
    ``trainable`` flag is captured when a model is compiled. Compile
    ``discriminator`` on its own *before* calling this function; otherwise it
    gets compiled frozen and its own ``train_on_batch`` calls will not update
    its weights.

    Args:
        generator: Keras model mapping ``(batch, latent_dim)`` noise to samples.
        discriminator: Keras model mapping samples to a real/fake score.
        latent_dim: length of the noise vector. Defaults to the width of the
            generator's input, so the function no longer depends on a
            module-level ``latent_dim`` global.

    Returns:
        An uncompiled Keras ``Model`` mapping noise to the discriminator score.
    """
    if latent_dim is None:
        # The generator's input is (None, latent_dim); read the width from it
        # rather than from a global that may not match.
        latent_dim = generator.inputs[0].shape[-1]
    discriminator.trainable = False
    gan_input = Input(shape=(latent_dim,))
    gan_output = discriminator(generator(gan_input))
    return Model(gan_input, gan_output)
# Load data.
# NOTE(review): assumes data.npy is a 2-D numeric (n_samples, n_features) array
# and labels.npy holds one class label per row, shape (n,) or (n, 1). Confirm.
data = np.load('data.npy').astype('float32')
labels = np.load('labels.npy')

# Hyperparameters.
latent_dim = 100
epochs = 10000      # number of training steps (one mini-batch each), not full passes
batch_size = 32
log_interval = 500  # print losses every `log_interval` steps

# Locate the minority class. The GAN is trained on minority rows only, so it
# learns to synthesize new minority samples. With more than two classes,
# repeat the procedure for every under-represented class.
flat_labels = labels.reshape(-1)
classes, counts = np.unique(flat_labels, return_counts=True)
minority_class = classes[np.argmin(counts)]
minority_data = data[flat_labels == minority_class]

# Min-max scale to [0, 1] to match the generator's sigmoid output range.
# Constant columns get a range of 1 to avoid division by zero.
feature_min = minority_data.min(axis=0)
feature_range = minority_data.max(axis=0) - feature_min
feature_range[feature_range == 0] = 1.0
scaled_minority = (minority_data - feature_min) / feature_range

# Build and compile the models. The discriminator is compiled *before*
# build_gan() freezes it, because tf.keras captures `trainable` at compile time.
# Compiling it afterwards would leave it with no trainable weights. The
# generator is not compiled on its own: it is only trained through `gan`.
generator = build_generator(latent_dim, output_dim=data.shape[1])
discriminator = build_discriminator(input_dim=data.shape[1])
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(.0002, .5))
gan = build_gan(generator, discriminator)
gan.compile(loss='binary_crossentropy', optimizer=Adam(.0002, .5))

# Adversarial targets: 1 = real, 0 = generated. These are deliberately NOT the
# class labels from labels.npy.
valid_labels = np.ones((batch_size, 1))
fake_labels = np.zeros((batch_size, 1))

# Train.
for epoch in range(epochs):
    # Discriminator step: one real minority batch and one generated batch.
    idx = np.random.randint(0, scaled_minority.shape[0], batch_size)
    real_data = scaled_minority[idx]
    noise = np.random.normal(0, 1, (batch_size, latent_dim))
    fake_data = generator.predict(noise, verbose=0)
    real_loss = discriminator.train_on_batch(real_data, valid_labels)
    fake_loss = discriminator.train_on_batch(fake_data, fake_labels)
    d_loss = .5 * np.add(real_loss, fake_loss)

    # Generator step: push generated samples to be scored as real by the
    # (frozen) discriminator inside `gan`.
    noise = np.random.normal(0, 1, (batch_size, latent_dim))
    g_loss = gan.train_on_batch(noise, valid_labels)

    if epoch % log_interval == 0 or epoch == epochs - 1:
        print(f'Epoch: {epoch}, D Loss: {d_loss}, G Loss: {g_loss}')

# Generate enough synthetic minority rows to match the majority-class count,
# map them back to the original feature scale, and append them to the data.
# TODO(review): integer / one-hot columns may need rounding or clipping here.
n_to_generate = int(counts.max() - counts.min())
if n_to_generate > 0:
    noise = np.random.normal(0, 1, (n_to_generate, latent_dim))
    generated_data = generator.predict(noise, verbose=0) * feature_range + feature_min
else:
    generated_data = np.empty((0, data.shape[1]), dtype=data.dtype)
generated_labels = np.full((n_to_generate,) + labels.shape[1:], minority_class,
                           dtype=labels.dtype)
balanced_data = np.concatenate([data, generated_data])
balanced_labels = np.concatenate([labels, generated_labels])
```
阅读全文