Write a GAN network in Python
Sure, here is a simple GAN implemented in Python with TensorFlow/Keras:
```python
import tensorflow as tf
from tensorflow.keras.layers import Dense, Reshape, Flatten
from tensorflow.keras.layers import Conv2D, Conv2DTranspose
from tensorflow.keras.layers import BatchNormalization, LeakyReLU
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
import numpy as np
import matplotlib.pyplot as plt
# Define the generator model
def build_generator(z_dim):
    model = Sequential()
    model.add(Dense(256, input_dim=z_dim))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(1024))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    # tanh output matches the [-1, 1] range of the normalized images
    model.add(Dense(28*28*1, activation='tanh'))
    model.add(Reshape((28, 28, 1)))
    return model
# Define the discriminator model
def build_discriminator(img_shape):
    model = Sequential()
    model.add(Conv2D(32, kernel_size=3, strides=2, input_shape=img_shape, padding='same'))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Conv2D(64, kernel_size=3, strides=2, padding='same'))
    model.add(BatchNormalization(momentum=0.8))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Conv2D(128, kernel_size=3, strides=2, padding='same'))
    model.add(BatchNormalization(momentum=0.8))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Conv2D(256, kernel_size=3, strides=1, padding='same'))
    model.add(BatchNormalization(momentum=0.8))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Flatten())
    model.add(Dense(1, activation='sigmoid'))
    return model
# Define the combined GAN model (generator followed by discriminator)
def build_gan(generator, discriminator):
    model = Sequential()
    model.add(generator)
    model.add(discriminator)
    return model
# Load the MNIST dataset
(X_train, _), (_, _) = tf.keras.datasets.mnist.load_data()
# Normalize images to [-1, 1] and add a channel dimension
X_train = X_train / 127.5 - 1.
X_train = np.expand_dims(X_train, axis=3)
# GAN hyperparameters
z_dim = 100
img_shape = (28, 28, 1)
# Build the generator, discriminator and combined GAN
generator = build_generator(z_dim)
discriminator = build_discriminator(img_shape)
gan = build_gan(generator, discriminator)
# Compile the discriminator, then freeze it inside the combined model
# so that training the GAN only updates the generator's weights
discriminator.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5), metrics=['accuracy'])
discriminator.trainable = False
gan.compile(loss='binary_crossentropy', optimizer=Adam(0.0002, 0.5))
# Training loop
def train(X_train, epochs=20000, batch_size=128, save_interval=1000):
    # Track discriminator loss/accuracy and generator loss
    d_loss = []
    d_acc = []
    g_loss = []
    for epoch in range(epochs):
        # Train the discriminator on a real batch and a generated batch
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        imgs = X_train[idx]
        noise = np.random.normal(0, 1, (batch_size, z_dim))
        gen_imgs = generator.predict(noise)
        d_loss_real = discriminator.train_on_batch(imgs, np.ones((batch_size, 1)))
        d_loss_fake = discriminator.train_on_batch(gen_imgs, np.zeros((batch_size, 1)))
        d_loss_epoch = 0.5 * np.add(d_loss_real, d_loss_fake)
        d_loss.append(d_loss_epoch[0])
        d_acc.append(d_loss_epoch[1])
        # Train the generator: it tries to make the discriminator label fakes as real
        noise = np.random.normal(0, 1, (batch_size, z_dim))
        g_loss_epoch = gan.train_on_batch(noise, np.ones((batch_size, 1)))
        g_loss.append(g_loss_epoch)
        # Print training progress
        print(f'epoch: {epoch}/{epochs}, d_loss: {d_loss_epoch[0]}, d_acc: {100*d_loss_epoch[1]}%, g_loss: {g_loss_epoch}')
        # Periodically save samples from the generator
        if epoch % save_interval == 0:
            save_imgs(generator, epoch)
# Save a 5x5 grid of generated digits to disk
def save_imgs(generator, epoch):
    r, c = 5, 5
    noise = np.random.normal(0, 1, (r * c, z_dim))
    gen_imgs = generator.predict(noise)
    # Rescale from [-1, 1] back to [0, 1] for plotting
    gen_imgs = 0.5 * gen_imgs + 0.5
    fig, axs = plt.subplots(r, c)
    cnt = 0
    for i in range(r):
        for j in range(c):
            axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
            axs[i, j].axis('off')
            cnt += 1
    fig.savefig(f'gan_mnist_{epoch}.png')
    plt.close()
# Train the GAN
train(X_train)
```
This GAN is designed for the MNIST handwritten-digit dataset: the generator is a stack of fully connected layers, while the discriminator uses several convolutional layers followed by a fully connected output. In each training step, the discriminator is updated first on a real batch and a generated batch, and then the generator is updated through the combined model with the discriminator frozen. GAN training is sensitive to hyperparameters, so expect to spend some time tuning.
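Once training finishes, the generator alone can be used to synthesize new digits. The following is a minimal usage sketch; it assumes the `generator` and `z_dim` defined in the script above (and the same `numpy`/`matplotlib` imports):

```python
# Sample latent vectors and decode them with the trained generator
noise = np.random.normal(0, 1, (16, z_dim))
samples = generator.predict(noise)     # shape (16, 28, 28, 1), values in [-1, 1]
samples = 0.5 * samples + 0.5          # rescale to [0, 1] for display

# Show the 16 generated digits in a 4x4 grid
fig, axs = plt.subplots(4, 4, figsize=(4, 4))
for i, ax in enumerate(axs.flat):
    ax.imshow(samples[i, :, :, 0], cmap='gray')
    ax.axis('off')
plt.show()
```

If the samples still look like noise after many epochs, common first tweaks are lowering the learning rate, applying one-sided label smoothing (e.g. real labels of 0.9 instead of 1.0), or updating the discriminator less frequently.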