基于MNIST,使用Conditional-GAN框架生成手写数字,给出完整代码
时间: 2024-11-30 15:23:14 浏览: 21
conditional-DCGAN-for-MNIST:Tensorflow中的条件DCGAN,用于从MNIST数据集中生成手写数字
基于MNIST数据集生成手写数字的Conditional GAN (条件生成对抗网络) 实现通常需要Python语言和相关的深度学习库,如TensorFlow或PyTorch。这里我将提供一个简化的使用TensorFlow实现的例子。注意:下面的示例实际上是一个**无条件** GAN 骨架——若要实现真正的 Conditional GAN,还需将类别标签(例如经 Embedding 后)与噪声向量及判别器输入拼接,使生成器和判别器都以标签为条件:
```python
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import (Conv2D, Conv2DTranspose, Dense, Dropout,
                                     Flatten, Input, Reshape)
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
# Load the MNIST training images (labels are discarded — this demo is unconditional).
(x_train, _), (_, _) = mnist.load_data()
# BUG FIX: dividing the uint8 array by a Python float yields float64; cast to
# float32 first so the data matches Keras' default dtype. Scaling to [0, 1]
# matches the generator's sigmoid output range.
x_train = x_train.astype("float32") / 255.0
# Add the trailing channel axis expected by Conv2D layers: (N, 28, 28, 1).
x_train = x_train.reshape(-1, 28, 28, 1)
# Generator model definition
def build_generator(latent_dim):
    """Build the DCGAN generator.

    Args:
        latent_dim: Length of the input noise vector.

    Returns:
        A ``tf.keras.Sequential`` mapping (batch, latent_dim) noise to
        (batch, 28, 28, 1) images with values in [0, 1] (sigmoid output).
    """
    model = tf.keras.Sequential([
        Input(shape=(latent_dim,)),
        Dense(7 * 7 * 64, activation='relu'),
        Reshape((7, 7, 64)),
        # Upsample 7x7 -> 14x14.
        Conv2DTranspose(64, (4, 4), strides=2, padding='same', activation='relu'),
        # Upsample 14x14 -> 28x28.
        Conv2DTranspose(32, (4, 4), strides=2, padding='same', activation='relu'),
        # BUG FIX: padding='same' keeps the output at 28x28x1; without it the
        # 'valid' default shrinks the map to 26x26x1, which does not match the
        # discriminator's (28, 28, 1) input shape.
        Conv2D(1, (3, 3), padding='same', activation='sigmoid'),
    ])
    return model
# Discriminator model definition
def build_discriminator(img_shape):
    """Build the DCGAN discriminator.

    Args:
        img_shape: Input image shape, e.g. ``(28, 28, 1)``.

    Returns:
        A ``tf.keras.Sequential`` mapping images to a single raw logit
        (no sigmoid) — higher means "more likely real".
    """
    model = tf.keras.Sequential([
        Input(shape=img_shape),
        # The redundant input_shape kwarg is dropped: the Input layer above
        # already fixes the input shape.
        Conv2D(64, (3, 3), strides=2, padding='same', activation='relu'),
        Conv2D(128, (3, 3), strides=2, padding='same', activation='relu'),
        Flatten(),
        # BUG FIX: emit a raw logit. The training loss is
        # BinaryCrossentropy(from_logits=True), so a sigmoid here would be
        # applied twice and badly distort the gradients.
        Dense(1),
    ])
    return model
# Length of the generator's input noise vector.
latent_dim = 100
generator = build_generator(latent_dim)
discriminator = build_discriminator((28, 28, 1))
# Separate optimizers for generator and discriminator; (0.0002, 0.5) are the
# learning rate and beta_1 commonly used for DCGAN training.
optimizer_g = Adam(0.0002, 0.5)
optimizer_d = Adam(0.0002, 0.5)
# from_logits=True means the discriminator's final layer must output raw
# logits (no sigmoid) for this loss to be computed correctly.
loss_fn = tf.keras.losses.BinaryCrossentropy(from_logits=True)
@tf.function
def train_step(real_images, noise):
    """Run one adversarial update of both networks.

    Args:
        real_images: Batch of real images, shape (batch, 28, 28, 1).
        noise: Batch of latent vectors, shape (batch, latent_dim).

    Returns:
        Tuple ``(gen_loss, disc_loss)`` of scalar tensors for monitoring.
    """
    # Two tapes: each network's gradients are taken w.r.t. its own loss.
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(noise, training=True)
        real_output = discriminator(real_images, training=True)
        fake_output = discriminator(generated_images, training=True)
        # Generator wants its fakes classified as real (target 1).
        gen_loss = loss_fn(tf.ones_like(fake_output), fake_output)
        # Discriminator: real images -> 1, generated images -> 0.
        disc_loss_real = loss_fn(tf.ones_like(real_output), real_output)
        disc_loss_fake = loss_fn(tf.zeros_like(fake_output), fake_output)
        disc_loss = disc_loss_real + disc_loss_fake
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    optimizer_g.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    optimizer_d.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
    # Returning the losses is backward-compatible (the caller may ignore them)
    # and enables progress logging.
    return gen_loss, disc_loss
# Training loop.
# BUG FIX: batch_size was referenced below but never defined (NameError on the
# first iteration).
batch_size = 128
num_epochs = 50
for epoch in range(num_epochs):
    # Derive the number of batches from the data instead of hard-coding 60000.
    for i in range(x_train.shape[0] // batch_size):
        real_images_batch = x_train[i * batch_size:(i + 1) * batch_size]
        noise = tf.random.normal([batch_size, latent_dim])
        train_step(real_images_batch, noise)

# Sample 16 latent vectors and display the first generated digit.
generated_images = generator.predict(tf.random.normal([16, latent_dim]))
plt.imshow(generated_images[0].reshape(28, 28), cmap='gray')
plt.show()
阅读全文