tensorflow2中基于dcgan
时间: 2024-02-04 10:02:33 浏览: 93
tensorflow实现的gan
DCGAN是一种基于生成对抗网络(GAN)的图像生成模型。在TensorFlow 2中,可以使用Keras API实现DCGAN。
以下是一个简单的DCGAN实现的示例代码:
```python
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import (
    BatchNormalization,
    Conv2D,
    Conv2DTranspose,
    Dense,
    Dropout,
    Flatten,
    LeakyReLU,
    Reshape,
)
from tensorflow.keras.models import Sequential
from tensorflow.keras.optimizers import Adam
# Define the generator model
def make_generator_model():
    """Build the DCGAN generator.

    The network takes a 100-dimensional noise vector and up-samples it with
    transposed convolutions into a 28x28 single-channel image. The last layer
    uses a tanh activation.

    Returns:
        The (uncompiled) Keras ``Sequential`` generator.
    """
    generator_net = Sequential()

    # Project the noise vector, then reshape it into a 7x7 map with 256 channels.
    generator_net.add(Dense(7 * 7 * 256, use_bias=False, input_shape=(100,)))
    generator_net.add(BatchNormalization())
    generator_net.add(LeakyReLU())
    generator_net.add(Reshape((7, 7, 256)))
    # None is the batch dimension, which is left unconstrained.
    assert generator_net.output_shape == (None, 7, 7, 256)

    # Hidden up-sampling stages: (filters, stride, expected spatial size).
    for filters, stride, size in ((128, 1, 7), (64, 2, 14)):
        generator_net.add(
            Conv2DTranspose(filters, (5, 5), strides=(stride, stride), padding='same', use_bias=False)
        )
        assert generator_net.output_shape == (None, size, size, filters)
        generator_net.add(BatchNormalization())
        generator_net.add(LeakyReLU())

    # Output stage: one channel, 28x28, squashed by tanh.
    generator_net.add(
        Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh')
    )
    assert generator_net.output_shape == (None, 28, 28, 1)
    return generator_net
# Define the discriminator model
def make_discriminator_model():
    """Build the DCGAN discriminator.

    The network takes a 28x28 single-channel image and reduces it with two
    strided convolutions (each followed by LeakyReLU and 30% dropout) to a
    single output unit. The final ``Dense(1)`` has no activation, so the
    output is a raw score (logit) rather than a probability.

    Requires ``Dropout`` to be imported from ``tensorflow.keras.layers``
    alongside the other layers.

    Returns:
        The (uncompiled) Keras ``Sequential`` discriminator.
    """
    model = Sequential()

    model.add(Conv2D(64, (5, 5), strides=(2, 2), padding='same', input_shape=[28, 28, 1]))
    model.add(LeakyReLU())
    model.add(Dropout(0.3))

    model.add(Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(LeakyReLU())
    model.add(Dropout(0.3))

    # Flatten the feature maps and map them to a single logit.
    model.add(Flatten())
    model.add(Dense(1))
    return model
# Define the loss function.
# from_logits=True: the loss receives raw scores rather than probabilities
# (the discriminator's last layer is a Dense(1) without an activation).
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
def discriminator_loss(real_output, fake_output):
    """Return the discriminator loss for one batch.

    Args:
        real_output: Discriminator outputs for real images (target label 1).
        fake_output: Discriminator outputs for generated images (target label 0).

    Returns:
        The sum of the cross-entropy on the real and on the fake outputs.
    """
    loss_on_real = cross_entropy(tf.ones_like(real_output), real_output)
    loss_on_fake = cross_entropy(tf.zeros_like(fake_output), fake_output)
    return loss_on_real + loss_on_fake
def generator_loss(fake_output):
    """Return the generator loss for one batch.

    The generator is scored against an all-ones target, i.e. it is rewarded
    when the discriminator rates its images as real.

    Args:
        fake_output: Discriminator outputs for generated images.
    """
    wanted_labels = tf.ones_like(fake_output)
    return cross_entropy(wanted_labels, fake_output)
# Define the optimizers. The generator and the discriminator each get their
# own Adam instance (learning rate 1e-4) because they are updated separately.
generator_optimizer = Adam(1e-4)
discriminator_optimizer = Adam(1e-4)
# Training step
@tf.function
def train_step(images):
    """Run one update of both the generator and the discriminator.

    Reads the module-level ``generator``, ``discriminator``, their optimizers,
    ``BATCH_SIZE`` and ``LATENT_DIM``.

    Args:
        images: A batch of real training images fed to the discriminator.
    """
    latent_vectors = tf.random.normal([BATCH_SIZE, LATENT_DIM])

    # Two tapes: each network's loss is differentiated w.r.t. its own variables.
    with tf.GradientTape() as g_tape, tf.GradientTape() as d_tape:
        fake_images = generator(latent_vectors, training=True)

        real_scores = discriminator(images, training=True)
        fake_scores = discriminator(fake_images, training=True)

        g_loss = generator_loss(fake_scores)
        d_loss = discriminator_loss(real_scores, fake_scores)

    g_vars = generator.trainable_variables
    d_vars = discriminator.trainable_variables

    g_grads = g_tape.gradient(g_loss, g_vars)
    d_grads = d_tape.gradient(d_loss, d_vars)

    generator_optimizer.apply_gradients(zip(g_grads, g_vars))
    discriminator_optimizer.apply_gradients(zip(d_grads, d_vars))
# Hyper-parameters
BUFFER_SIZE = 60000
BATCH_SIZE = 256
LATENT_DIM = 100
EPOCHS = 100
noise_dim = 100
num_examples_to_generate = 16

# Generator
generator = make_generator_model()
# Discriminator
discriminator = make_discriminator_model()

# Dataset: MNIST training images with an explicit channel axis.
(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
train_images = train_images.reshape((-1, 28, 28, 1)).astype('float32')
train_images = (train_images - 127.5) / 127.5  # normalise the images to [-1, 1]
train_dataset = tf.data.Dataset.from_tensor_slices(train_images)
train_dataset = train_dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

# This seed is reused for every saved image grid, so progress is easier to
# compare between epochs (e.g. in an animated GIF).
seed = tf.random.normal([num_examples_to_generate, noise_dim])
# Generate and save images.
# Defined before the training loop because the loop calls it in the very
# first epoch; defining it afterwards raises NameError at that call.
def generate_and_save_images(model, epoch, test_input):
    """Plot the model's output for ``test_input`` on a 4x4 grid and save it.

    The figure is written to ``image_at_epoch_XXXX.png`` in the current
    working directory, shown, and then closed.

    Args:
        model: Callable invoked as ``model(test_input, training=False)``; its
            result is indexed as ``[i, :, :, 0]``.
        epoch: Integer used to build the output file name.
        test_input: Input batch for ``model``. The grid has 16 cells, so the
            batch should contain at most 16 samples.
    """
    # training=False runs every layer in inference mode.
    predictions = model(test_input, training=False)

    fig = plt.figure(figsize=(4, 4))
    for i in range(predictions.shape[0]):
        plt.subplot(4, 4, i + 1)
        # Undo the (x - 127.5) / 127.5 normalisation applied to the training images.
        plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
        plt.axis('off')

    plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
    plt.show()
    # A new figure is created on every call; close it so figures do not pile up.
    plt.close(fig)


# Training loop
for epoch in range(EPOCHS):
    for image_batch in train_dataset:
        train_step(image_batch)

    # Save the generator's output every 10 epochs.
    if epoch % 10 == 0:
        generate_and_save_images(generator, epoch + 1, seed)

# Save the generator's output once more after the final epoch.
generate_and_save_images(generator, EPOCHS, seed)
```
在这个代码示例中,我们使用了MNIST数据集来训练模型。训练过程中,我们使用了Adam优化器和二元交叉熵损失函数,并且每隔10个epoch(以及训练结束后),我们都会用同一组固定的随机噪声生成一些图片并保存下来,以便于我们查看生成器的输出。
阅读全文