使用生成对抗网络(GAN)生成图片并保存到指定文件夹的 Python 代码
时间: 2023-12-31 11:04:12 浏览: 68
下面是一个使用生成对抗网络(GAN)生成图片、并将图片保存到指定文件夹的简单 Python 示例(依赖 TensorFlow/Keras、NumPy 和 Matplotlib):
```python
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Dropout
from tensorflow.keras.layers import BatchNormalization, Activation, ZeroPadding2D
from tensorflow.keras.layers import LeakyReLU
from tensorflow.keras.layers import UpSampling2D, Conv2D
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.optimizers import Adam
# Generator model definition.
def build_generator(latent_dim=100, img_shape=(28, 28, 1)):
    """Build the generator network that maps a noise vector to an image.

    The network is a stack of fully connected layers (256 -> 512 -> 1024
    units, each followed by LeakyReLU and BatchNormalization). A final
    ``tanh`` Dense layer produces ``prod(img_shape)`` values, which are
    reshaped into one image whose pixels lie in [-1, 1].

    Args:
        latent_dim: Length of the input noise vector. Defaults to 100.
        img_shape: Shape of one generated image. Defaults to
            ``(28, 28, 1)`` (single-channel 28x28, as used for MNIST).

    Returns:
        A Keras ``Model`` that takes a ``(batch, latent_dim)`` tensor and
        returns a ``(batch, *img_shape)`` tensor.
    """
    img_shape = tuple(img_shape)
    model = Sequential()
    model.add(Dense(256, input_dim=latent_dim))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    model.add(Dense(1024))
    model.add(LeakyReLU(alpha=0.2))
    model.add(BatchNormalization(momentum=0.8))
    # tanh keeps outputs in [-1, 1], the same range the training images are
    # scaled to in this script. int() because np.prod returns a NumPy
    # integer and Dense expects a plain unit count.
    model.add(Dense(int(np.prod(img_shape)), activation='tanh'))
    model.add(Reshape(img_shape))
    # Wrap the Sequential stack in a functional Model with an explicit input.
    noise = Input(shape=(latent_dim,))
    img = model(noise)
    return Model(noise, img)
# Discriminator model definition.
def build_discriminator(img_shape=(28, 28, 1)):
    """Build the discriminator network that scores images as real or fake.

    The image is flattened and passed through fully connected layers
    (512 -> 256 units, each followed by LeakyReLU). A final single-unit
    ``sigmoid`` layer outputs a probability in [0, 1] that the input is
    a real image.

    Args:
        img_shape: Shape of one input image. Defaults to ``(28, 28, 1)``
            (single-channel 28x28, as used for MNIST).

    Returns:
        A Keras ``Model`` that takes a ``(batch, *img_shape)`` tensor and
        returns a ``(batch, 1)`` tensor of probabilities.
    """
    img_shape = tuple(img_shape)
    model = Sequential()
    model.add(Flatten(input_shape=img_shape))
    model.add(Dense(512))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(256))
    model.add(LeakyReLU(alpha=0.2))
    model.add(Dense(1, activation='sigmoid'))
    # Wrap the Sequential stack in a functional Model with an explicit input.
    img = Input(shape=img_shape)
    validity = model(img)
    return Model(img, validity)
# Load the MNIST training images (labels and the test split are unused).
(X_train, _), (_, _) = mnist.load_data()
# Scale pixels from [0, 255] to [-1, 1] to match the generator's tanh output.
X_train = (X_train.astype(np.float32) - 127.5) / 127.5
# Add a trailing channel axis: (N, 28, 28) -> (N, 28, 28, 1).
X_train = np.expand_dims(X_train, axis=3)

# Training parameters.
# NOTE: each "epoch" below is ONE mini-batch update, not a full pass over
# the dataset.
epochs = 10000
batch_size = 32
save_interval = 1000      # save a sample grid every this many iterations
latent_dim = 100          # noise length; must match build_generator()'s input
output_dir = 'images'     # folder the generated sample grids are written to

# Build and compile the adversarial network.
# Each compiled model gets its OWN Adam instance: since TF 2.11, Keras
# optimizers are built for one fixed set of variables and raise an error
# when asked to update a different set, so one shared optimizer cannot
# serve both the discriminator and the combined model.
discriminator = build_discriminator()
discriminator.compile(loss='binary_crossentropy',
                      optimizer=Adam(learning_rate=0.0002, beta_1=0.5),
                      metrics=['accuracy'])
generator = build_generator()
z = Input(shape=(latent_dim,))
img = generator(z)
# Inside `combined` the discriminator is frozen, so training `combined`
# only updates the generator.
discriminator.trainable = False
valid = discriminator(img)
combined = Model(z, valid)
combined.compile(loss='binary_crossentropy',
                 optimizer=Adam(learning_rate=0.0002, beta_1=0.5))

# Target labels never change, so build them once outside the loop.
real_labels = np.ones((batch_size, 1))
fake_labels = np.zeros((batch_size, 1))


def save_sample_grid(model, path, latent_dim, rows=5, cols=5):
    """Sample rows*cols images from `model` and save them as one PNG at `path`."""
    noise = np.random.normal(0, 1, (rows * cols, latent_dim))
    samples = model.predict(noise, verbose=0)
    # Map the tanh output from [-1, 1] back to [0, 1] for display.
    samples = 0.5 * samples + 0.5
    # squeeze=False keeps `axs` a 2-D array even for a 1x1 grid.
    fig, axs = plt.subplots(rows, cols, squeeze=False)
    for cnt, ax in enumerate(axs.flat):
        ax.imshow(samples[cnt, :, :, 0], cmap='gray')
        ax.axis('off')
    fig.savefig(path)
    # Close this specific figure so open figures don't pile up across saves.
    plt.close(fig)


os.makedirs(output_dir, exist_ok=True)

# Training loop.
for epoch in range(epochs):
    # Pick a random batch of real images.
    idx = np.random.randint(0, X_train.shape[0], batch_size)
    imgs = X_train[idx]
    # Sample noise and generate a batch of fake images.
    noise = np.random.normal(0, 1, (batch_size, latent_dim))
    gen_imgs = generator.predict(noise, verbose=0)

    # Train the discriminator: real images -> 1, fake images -> 0.
    # `trainable` is set explicitly before each update. Depending on the
    # Keras version, the flag is read either at compile time or when a
    # model's train function is first built; toggling it here gives the
    # intended behaviour in both cases.
    discriminator.trainable = True
    d_loss_real = discriminator.train_on_batch(imgs, real_labels)
    d_loss_fake = discriminator.train_on_batch(gen_imgs, fake_labels)
    d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

    # Train the generator (discriminator frozen) to make fakes score as real.
    discriminator.trainable = False
    g_loss = combined.train_on_batch(noise, real_labels)

    # Report losses and discriminator accuracy.
    print("Epoch %d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100 * d_loss[1], g_loss))

    # Periodically save a grid of generated samples, including after the
    # final iteration so the last trained state is always captured.
    if epoch % save_interval == 0 or epoch == epochs - 1:
        save_sample_grid(generator,
                         os.path.join(output_dir, "mnist_%d.png" % epoch),
                         latent_dim)
```
这段代码在 MNIST 数据集上训练一个简单的生成对抗网络(GAN)来生成手写数字图片,并定期把生成的 5×5 样例网格保存到名为 "images" 的文件夹中。注意:代码中的 epochs 实际上是训练迭代(每次一个小批量)的次数,而不是完整遍历数据集的轮数。你可以根据需要修改输出文件夹、迭代次数、批大小等参数。
阅读全文