GAN 代码
时间: 2023-07-02 10:08:12 浏览: 54
以下是一个简单的 GAN 代码示例,用于生成手写数字图像:
```python
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Dense, Flatten, Reshape, Input, Dropout, LeakyReLU
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
from tqdm import tqdm
# 加载 MNIST 数据集
(X_train, _), (_, _) = mnist.load_data()
# 将像素值缩放到 -1 到 1 之间
X_train = (X_train.astype(np.float32) - 127.5) / 127.5
X_train = np.expand_dims(X_train, axis=3)
# 定义生成器模型
generator = Sequential([
Dense(256, input_shape=(100,)),
LeakyReLU(alpha=0.2),
Dense(512),
LeakyReLU(alpha=0.2),
Dense(1024),
LeakyReLU(alpha=0.2),
Dense(28 * 28 * 1, activation='tanh'),
Reshape((28, 28, 1))
])
# 定义判别器模型
discriminator = Sequential([
Flatten(input_shape=(28, 28, 1)),
Dense(512),
LeakyReLU(alpha=0.2),
Dropout(0.3),
Dense(256),
LeakyReLU(alpha=0.2),
Dropout(0.3),
Dense(1, activation='sigmoid')
])
# 编译判别器模型
discriminator.compile(optimizer=Adam(lr=0.0002, beta_1=0.5), loss='binary_crossentropy', metrics=['accuracy'])
# 将判别器设置为不可训练
discriminator.trainable = False
# 定义 GAN 模型
gan_input = Input(shape=(100,))
gan_output = discriminator(generator(gan_input))
gan = Model(gan_input, gan_output)
# 编译 GAN 模型
gan.compile(optimizer=Adam(lr=0.0002, beta_1=0.5), loss='binary_crossentropy')
# 训练 GAN 模型
epochs = 10000
batch_size = 128
for epoch in range(epochs):
# 选择一个随机的样本批次
idx = np.random.randint(0, X_train.shape[0], batch_size)
real_images = X_train[idx]
# 生成一批假图像
noise = np.random.normal(0, 1, (batch_size, 100))
fake_images = generator.predict(noise)
# 训练判别器模型
discriminator_loss_real = discriminator.train_on_batch(real_images, np.ones((batch_size, 1)))
discriminator_loss_fake = discriminator.train_on_batch(fake_images, np.zeros((batch_size, 1)))
discriminator_loss = 0.5 * np.add(discriminator_loss_real, discriminator_loss_fake)
# 训练生成器模型
noise = np.random.normal(0, 1, (batch_size, 100))
generator_loss = gan.train_on_batch(noise, np.ones((batch_size, 1)))
# 输出训练过程
if epoch % 100 == 0:
print(f'Epoch: {epoch}, Discriminator Loss: {discriminator_loss}, Generator Loss: {generator_loss}')
# 生成一些手写数字图像
noise = np.random.normal(0, 1, (10, 100))
generated_images = generator.predict(noise)
generated_images = 0.5 * generated_images + 0.5
fig, axs = plt.subplots(1, 10, figsize=(20, 2))
cnt = 0
for i in range(10):
axs[i].imshow(generated_images[cnt, :, :, 0], cmap='gray')
axs[i].axis('off')
cnt += 1
plt.show()
```
在该代码中,我们使用了 Keras 框架来定义生成器和判别器模型,并将它们组合成一个 GAN 模型。我们使用 MNIST 数据集作为示例数据集,并在训练过程中生成一些手写数字图像以进行可视化。在训练过程中,我们首先训练判别器模型来区分真实图像和假图像,然后训练生成器模型来生成更逼真的假图像。最终,我们可以得到一个可以生成手写数字图像的 GAN 模型。