CGAN代码
时间: 2023-07-03 15:18:22 浏览: 174
以下是一个简单的CGAN的Python代码示例,用于生成手写数字图像。假设我们想要在MNIST数据集上生成数字4的图像:
```python
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras import Model
from tensorflow.keras.layers import (BatchNormalization, Concatenate, Conv2D,
                                     Conv2DTranspose, Dense, Dropout, Embedding,
                                     Flatten, Input, LeakyReLU, Reshape)
# Load MNIST and keep only the digit-4 images; rescale pixels to [-1, 1]
# to match the generator's tanh output range.
(x_train, y_train), (_, _) = tf.keras.datasets.mnist.load_data()
x_train = x_train[y_train == 4]
x_train = x_train.reshape(-1, 28, 28, 1).astype('float32')
x_train = (x_train - 127.5) / 127.5

# Hyperparameters.
BUFFER_SIZE = 60000
BATCH_SIZE = 256
EPOCHS = 50
NUM_CLS = 10    # number of digit classes
NOISE_DIM = 100
# Generator network: produces fake 28x28 images from a (noise ++ label) vector.
def make_generator_model():
    """Build the conditional generator.

    Input:  vector of length NOISE_DIM + NUM_CLS (noise concatenated with a
            one-hot class label).
    Output: (28, 28, 1) image; tanh activation keeps values in [-1, 1],
            matching the training-data scaling.
    """
    model = tf.keras.Sequential([
        Dense(7 * 7 * 256, use_bias=False, input_shape=(NOISE_DIM + NUM_CLS,)),
        BatchNormalization(),
        LeakyReLU(),
        Reshape((7, 7, 256)),
        # 7x7 -> 7x7 (stride 1), 256 -> 128 channels
        Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same',
                        use_bias=False),
        BatchNormalization(),
        LeakyReLU(),
        # 7x7 -> 14x14
        Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same',
                        use_bias=False),
        BatchNormalization(),
        LeakyReLU(),
        # 14x14 -> 28x28, single channel
        Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same',
                        use_bias=False, activation='tanh'),
    ])
    assert model.output_shape == (None, 28, 28, 1)  # batch size unconstrained
    return model
# Discriminator network: scores 28x28x1 images as real (high logit) or fake.
# NOTE(review): this discriminator receives only the image, never the class
# label, so the overall setup is not a fully conditional GAN — confirm whether
# label conditioning of the discriminator was intended.
def make_discriminator_model():
    """Build the discriminator; outputs one unbounded real/fake logit."""
    model = tf.keras.Sequential([
        Conv2D(64, (5, 5), strides=(2, 2), padding='same',
               input_shape=[28, 28, 1]),
        LeakyReLU(),
        Dropout(0.3),
        Conv2D(128, (5, 5), strides=(2, 2), padding='same'),
        LeakyReLU(),
        Dropout(0.3),
        Flatten(),
        Dense(1),  # raw logit; the loss is built with from_logits=True
    ])
    return model
# Convert integer class labels into one-hot encodings.
def make_one_hot(labels):
    """Return a (len(labels), NUM_CLS) one-hot tensor for 1-D class indices.

    The indices are cast to int32 first: tf.one_hot requires integer indices,
    and the original crashed when handed float labels (the training loop
    passes np.ones(BATCH_SIZE) * 4, which is float64). The original
    tf.reshape(..., (len(labels), NUM_CLS)) was also dropped — for a 1-D
    input tf.one_hot already returns that shape, and len() fails on tensors.
    """
    labels = tf.cast(labels, tf.int32)
    return tf.one_hot(labels, NUM_CLS)
# Loss objects. from_logits=True: both networks emit raw logits
# (no sigmoid/softmax on their final Dense layers).
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)
categorical_crossentropy = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
# Generator loss: measures how well generated images fool the discriminator.
def generator_loss(fake_output, labels):
    """Return the generator's adversarial loss.

    fake_output: discriminator logits for generated images, shape (batch, 1).
    labels:      one-hot class vectors; kept for caller compatibility but
                 unused — the discriminator outputs a single real/fake logit,
                 so the original `categorical_crossentropy(labels, fake_output)`
                 term compared incompatible shapes ((batch, NUM_CLS) vs
                 (batch, 1)) and has been removed as a bug.
    """
    return cross_entropy(tf.ones_like(fake_output), fake_output)
# Discriminator loss: real images should score toward 1, fakes toward 0.
def discriminator_loss(real_output, fake_output):
    """Sum of BCE on real logits (target 1) and on fake logits (target 0)."""
    loss_on_real = cross_entropy(tf.ones_like(real_output), real_output)
    loss_on_fake = cross_entropy(tf.zeros_like(fake_output), fake_output)
    return loss_on_real + loss_on_fake
# Optimizers — separate Adam instances (lr=1e-4) for each network.
generator_optimizer = tf.keras.optimizers.Adam(1e-4)
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)
# Instantiate the generator and discriminator networks.
generator = make_generator_model()
discriminator = make_discriminator_model()
# One optimization step: update both networks on a single batch.
@tf.function
def train_step(images, labels):
    """Run one generator + one discriminator update.

    images: float32 batch of real images, shape (batch, 28, 28, 1).
    labels: one-hot class vectors, shape (batch, NUM_CLS).
    Returns (gen_loss, disc_loss) so callers can monitor training
    (the original returned nothing; ignoring the return is still fine).
    """
    # Derive the noise batch size from the actual input rather than the
    # global BATCH_SIZE constant, so a short final batch no longer breaks
    # the noise/label concatenation.
    batch = tf.shape(images)[0]
    noise = tf.random.normal([batch, NOISE_DIM])
    # Cast defensively: tf.concat requires matching dtypes with the noise.
    gen_input = tf.concat([noise, tf.cast(labels, tf.float32)], axis=1)
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        generated_images = generator(gen_input, training=True)
        real_output = discriminator(images, training=True)
        fake_output = discriminator(generated_images, training=True)
        gen_loss = generator_loss(fake_output, labels)
        disc_loss = discriminator_loss(real_output, fake_output)
    gen_grads = gen_tape.gradient(gen_loss, generator.trainable_variables)
    disc_grads = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
    generator_optimizer.apply_gradients(
        zip(gen_grads, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(
        zip(disc_grads, discriminator.trainable_variables))
    return gen_loss, disc_loss
# Training loop: iterate over full batches, then render one sample per epoch.
for epoch in range(EPOCHS):
    for i in range(x_train.shape[0] // BATCH_SIZE):
        images = x_train[i * BATCH_SIZE:(i + 1) * BATCH_SIZE]
        # np.full with an int fill value yields integer labels; the original
        # np.ones(BATCH_SIZE) * 4 produced float64 values, which tf.one_hot
        # rejects as indices.
        labels = make_one_hot(np.full(BATCH_SIZE, 4))  # condition on digit 4
        train_step(images, labels)

    # Generate and display one digit-4 sample at the end of each epoch.
    noise = tf.random.normal([1, NOISE_DIM])
    label = make_one_hot(np.array([4]))
    gen_input = tf.concat([noise, label], axis=1)
    generated_image = generator(gen_input, training=False)
    # Undo the [-1, 1] scaling back to [0, 255] pixel values for display.
    generated_image = generated_image * 127.5 + 127.5
    generated_image = generated_image.numpy().squeeze().astype('uint8')
    plt.imshow(generated_image, cmap='gray')
    plt.axis('off')
    plt.show()
```
这段代码中使用了MNIST数据集,通过生成器生成数字4的图像。在训练过程中,我们将生成器和判别器进行交替训练,并通过优化器更新网络参数。需要注意的是,本示例中的判别器只接收图像而没有接收类别信息,因此严格来说这是一个简化的条件GAN示例;完整的CGAN还应将类别标签同时输入判别器。
阅读全文