cgan python代码
时间: 2023-07-03 08:24:01 浏览: 185
CGAN源代码
4星 · 用户满意度95%
以下是使用Python实现CGAN的示例代码:
```python
import numpy as np
import matplotlib.pyplot as plt
from keras.datasets import mnist
from keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply
from keras.layers import BatchNormalization, Activation, Embedding, ZeroPadding2D
from keras.layers import Conv2DTranspose, UpSampling2D, Conv2D
from keras.layers.advanced_activations import LeakyReLU
from keras.models import Sequential, Model
from keras.optimizers import Adam
from keras.utils import to_categorical
class CGAN():
    """Conditional GAN for 28x28 single-channel MNIST digits.

    The generator maps (noise, class label) to an image and the
    discriminator maps (image, class label) to a real/fake probability.
    Both condition on the label by multiplying their primary input with a
    learned label embedding of matching size.
    """

    def __init__(self):
        """Build and compile the discriminator, generator and combined model."""
        self.img_rows = 28
        self.img_cols = 28
        self.channels = 1
        self.img_shape = (self.img_rows, self.img_cols, self.channels)
        self.num_classes = 10
        self.latent_dim = 100

        optimizer = Adam(0.0002, 0.5)

        # Build and compile the discriminator.
        self.discriminator = self.build_discriminator()
        self.discriminator.compile(loss='binary_crossentropy',
                                   optimizer=optimizer,
                                   metrics=['accuracy'])

        # Build the generator.
        self.generator = self.build_generator()

        # The generator takes noise and a label as two separate inputs and
        # performs the label embedding internally, so it must be called with
        # both tensors rather than with a pre-multiplied single tensor.
        z = Input(shape=(self.latent_dim,))
        label = Input(shape=(1,), dtype='int32')
        img = self.generator([z, label])

        # Freeze the discriminator inside the combined model so that only
        # the generator is updated when the combined model is trained.
        self.discriminator.trainable = False

        # The discriminator scores the generated image given the same label.
        valid = self.discriminator([img, label])

        # Combined model: (noise, label) -> discriminator verdict.
        self.combined = Model([z, label], valid)
        self.combined.compile(loss='binary_crossentropy', optimizer=optimizer)

    def build_generator(self):
        """Return a Model mapping [noise, label] to a generated image.

        The noise vector is multiplied element-wise with a label embedding
        of size ``latent_dim`` and then upsampled 7x7 -> 14x14 -> 28x28.
        The final activation is tanh, so outputs lie in [-1, 1].
        """
        model = Sequential()
        model.add(Dense(128 * 7 * 7, activation="relu", input_dim=self.latent_dim))
        model.add(Reshape((7, 7, 128)))
        model.add(BatchNormalization(momentum=0.8))
        model.add(UpSampling2D())
        model.add(Conv2D(128, kernel_size=3, padding="same"))
        model.add(Activation("relu"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(UpSampling2D())
        model.add(Conv2D(64, kernel_size=3, padding="same"))
        model.add(Activation("relu"))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Conv2D(self.channels, kernel_size=3, padding="same"))
        model.add(Activation("tanh"))
        model.summary()

        noise = Input(shape=(self.latent_dim,))
        label = Input(shape=(1,), dtype='int32')
        label_embedding = Flatten()(Embedding(self.num_classes, self.latent_dim)(label))
        model_input = multiply([noise, label_embedding])
        img = model(model_input)

        return Model([noise, label], img)

    def build_discriminator(self):
        """Return a Model mapping [image, label] to a real/fake probability.

        The flattened image is multiplied element-wise with a label
        embedding of the same length, reshaped back to ``img_shape`` and
        passed through a strided convolutional classifier with a sigmoid
        output.
        """
        model = Sequential()
        model.add(Conv2D(16, kernel_size=3, strides=2, input_shape=self.img_shape, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(Conv2D(32, kernel_size=3, strides=2, padding="same"))
        model.add(ZeroPadding2D(padding=((0, 1), (0, 1))))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Conv2D(64, kernel_size=3, strides=2, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(BatchNormalization(momentum=0.8))
        model.add(Conv2D(128, kernel_size=3, strides=1, padding="same"))
        model.add(LeakyReLU(alpha=0.2))
        model.add(Dropout(0.25))
        model.add(Flatten())
        model.add(Dense(1, activation='sigmoid'))
        model.summary()

        img = Input(shape=self.img_shape)
        label = Input(shape=(1,), dtype='int32')
        label_embedding = Flatten()(
            Embedding(self.num_classes, int(np.prod(self.img_shape)))(label))
        flat_img = Flatten()(img)
        conditioned = multiply([flat_img, label_embedding])
        # The convolutional stack expects an image-shaped tensor, so the
        # conditioned vector has to be reshaped back before it is fed in.
        model_input = Reshape(self.img_shape)(conditioned)
        validity = model(model_input)

        return Model([img, label], validity)

    def train(self, epochs, batch_size=128, sample_interval=50):
        """Alternate discriminator and generator updates on MNIST.

        Args:
            epochs: Number of training iterations; each iteration uses one
                randomly drawn batch, not a full pass over the dataset.
            batch_size: Number of real and of generated samples per
                iteration.
            sample_interval: A grid of generated digits is saved every
                ``sample_interval`` iterations.
        """
        # Load the dataset.
        (X_train, y_train), (_, _) = mnist.load_data()

        # Rescale pixels from [0, 255] to [-1, 1] to match the generator's
        # tanh output, and add the channel axis.
        X_train = X_train / 127.5 - 1.
        X_train = np.expand_dims(X_train, axis=3)
        y_train = y_train.reshape(-1, 1)

        # Discriminator targets: 1 for real, 0 for generated.
        valid = np.ones((batch_size, 1))
        fake = np.zeros((batch_size, 1))

        for epoch in range(epochs):
            # ---------------------
            #  Train the discriminator
            # ---------------------
            # Pick a random batch of real images with their labels.
            idx = np.random.randint(0, X_train.shape[0], batch_size)
            imgs, labels = X_train[idx], y_train[idx]

            # Generate a batch of fake images for random labels.
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            gen_labels = np.random.randint(0, self.num_classes, batch_size).reshape(-1, 1)
            gen_imgs = self.generator.predict([noise, gen_labels])

            d_loss_real = self.discriminator.train_on_batch([imgs, labels], valid)
            d_loss_fake = self.discriminator.train_on_batch([gen_imgs, gen_labels], fake)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

            # ---------------------
            #  Train the generator
            # ---------------------
            # Fresh noise and labels; the generator is rewarded when the
            # (frozen) discriminator classifies its output as real.
            noise = np.random.normal(0, 1, (batch_size, self.latent_dim))
            gen_labels = np.random.randint(0, self.num_classes, batch_size).reshape(-1, 1)
            g_loss = self.combined.train_on_batch([noise, gen_labels], valid)

            # Report progress.
            print ("%d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100*d_loss[1], g_loss))

            # Periodically save a grid of generated samples.
            if epoch % sample_interval == 0:
                self.sample_images(epoch)

    def _grid_labels(self, rows, cols):
        """Return a (rows * cols, 1) label array cycling through the classes.

        Labels are assigned in row-major order as ``index % num_classes``,
        giving exactly one label per cell of the sample grid.
        """
        return (np.arange(rows * cols) % self.num_classes).reshape(-1, 1)

    def sample_images(self, epoch):
        """Generate a 10x10 grid of digits and save it as a PNG.

        Args:
            epoch: Iteration number, used in the output file name
                ``images/mnist_<epoch>.png``.
        """
        import os

        r, c = 10, 10
        noise = np.random.normal(0, 1, (r * c, self.latent_dim))
        # One label per noise vector; the two batch sizes must match.
        gen_labels = self._grid_labels(r, c)
        gen_imgs = self.generator.predict([noise, gen_labels])

        # Rescale pixel values from [-1, 1] to [0, 1] for display.
        gen_imgs = 0.5 * gen_imgs + 0.5

        fig, axs = plt.subplots(r, c)
        cnt = 0
        for i in range(r):
            for j in range(c):
                axs[i, j].imshow(gen_imgs[cnt, :, :, 0], cmap='gray')
                axs[i, j].axis('off')
                cnt += 1

        # Make sure the output directory exists before saving.
        os.makedirs("images", exist_ok=True)
        fig.savefig("images/mnist_%d.png" % epoch)
        # Close this specific figure so repeated sampling does not
        # accumulate open figures.
        plt.close(fig)
if __name__ == '__main__':
    # Script entry point: build the models, then run training with
    # periodic sample-image dumps.
    model = CGAN()
    model.train(epochs=20000, batch_size=32, sample_interval=200)
```
这里使用了Keras框架来实现CGAN模型,代码中包含了生成器和鉴别器的构建函数,以及训练函数和生成图像的函数。在每一次训练迭代中,会随机选取一批带标签的真实图像,并由生成器根据随机噪声和随机标签生成一批假图像,这两批图像共同用于训练鉴别器;随后在鉴别器权重固定的情况下,通过复合模型训练生成器。同时,每隔 sample_interval 次迭代会生成一组假图像,并保存到本地的 images 目录中,以便观察训练效果。
阅读全文