GAN算法Python代码
时间: 2023-10-02 12:14:13 浏览: 120
以下是一个简单的GAN（生成对抗网络）算法的Python代码示例（基于TensorFlow 1.x API，如tf.placeholder、tf.Session，无法直接在TensorFlow 2.x中运行）：
```python
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
# Generator network.
def generator(z, output_dim, n_units=128, reuse=False, alpha=0.01):
    """Build the generator graph that maps noise ``z`` to a fake sample.

    Architecture: one fully connected hidden layer of width ``n_units`` with a
    Leaky ReLU written as ``max(alpha * h, h)``, followed by a fully connected
    layer of width ``output_dim`` squashed by ``tanh`` into (-1, 1).

    Args:
        z: Noise input tensor; presumably shaped (batch, noise_dim) -- it is
            fed from a ``[None, input_dim]`` placeholder in this script.
        output_dim: Width of each generated sample (28*28 in this script).
        n_units: Number of hidden units.
        reuse: Forwarded to ``tf.variable_scope('generator', ...)``; pass True
            to share the variables created by an earlier call.
        alpha: Negative-side slope of the Leaky ReLU.

    Returns:
        The ``tanh``-activated output tensor.
    """
    with tf.variable_scope('generator', reuse=reuse):
        hidden = tf.layers.dense(z, n_units, activation=None)
        # Element-wise max implements Leaky ReLU for 0 <= alpha <= 1.
        hidden = tf.maximum(alpha * hidden, hidden)
        pre_tanh = tf.layers.dense(hidden, output_dim, activation=None)
        return tf.tanh(pre_tanh)
# Discriminator network.
def discriminator(x, n_units=128, reuse=False, alpha=0.01):
    """Build the discriminator graph that scores how "real" ``x`` looks.

    Architecture: one fully connected hidden layer of width ``n_units`` with a
    Leaky ReLU written as ``max(alpha * h, h)``, followed by a single-unit
    linear layer whose output is the logit.

    Args:
        x: Input batch; in this script either real images or the generator's
            output, both flattened to ``output_dim`` features.
        n_units: Number of hidden units.
        reuse: Forwarded to ``tf.variable_scope('discriminator', ...)``; pass
            True so a second call scores inputs with the same weights.
        alpha: Negative-side slope of the Leaky ReLU.

    Returns:
        Tuple ``(out, logits)``: the sigmoid probability and the raw logit.
    """
    with tf.variable_scope('discriminator', reuse=reuse):
        pre_activation = tf.layers.dense(x, n_units, activation=None)
        # Element-wise max implements Leaky ReLU for 0 <= alpha <= 1.
        hidden = tf.maximum(alpha * pre_activation, pre_activation)
        score = tf.layers.dense(hidden, 1, activation=None)
        probability = tf.sigmoid(score)
    return probability, score
# ---- Graph inputs ------------------------------------------------------------
input_dim = 100         # length of the generator's noise vector
output_dim = 28 * 28    # flattened image size; also the generator output width

# Clear the default graph before building the model.
tf.reset_default_graph()
X = tf.placeholder(tf.float32, shape=[None, output_dim], name='real_input')
Z = tf.placeholder(tf.float32, shape=[None, input_dim], name='input_noise')

# ---- Hyperparameters ---------------------------------------------------------
g_units = 128           # generator hidden width
d_units = 128           # discriminator hidden width
alpha = 0.01            # Leaky ReLU negative slope
learning_rate = 0.001   # Adam step size for both networks
smooth = 0.1            # one-sided label smoothing: real targets are 1 - smooth

# ---- Networks ----------------------------------------------------------------
G = generator(Z, output_dim, n_units=g_units, alpha=alpha)
D_output_real, D_logits_real = discriminator(X, n_units=d_units, alpha=alpha)
# reuse=True: the fake branch is scored by the same discriminator variables.
D_output_fake, D_logits_fake = discriminator(G, n_units=d_units, reuse=True, alpha=alpha)


def _mean_sigmoid_xent(logits, labels):
    """Return the batch mean of the sigmoid cross-entropy of ``logits`` vs ``labels``."""
    return tf.reduce_mean(
        tf.nn.sigmoid_cross_entropy_with_logits(logits=logits, labels=labels))


def _vars_with_prefix(variables, prefix):
    """Return the variables whose name starts with ``prefix``, in original order."""
    return [v for v in variables if v.name.startswith(prefix)]


# ---- Losses ------------------------------------------------------------------
# Discriminator: real samples -> (smoothed) 1, generated samples -> 0.
d_loss_real = _mean_sigmoid_xent(D_logits_real, tf.ones_like(D_output_real) * (1 - smooth))
d_loss_fake = _mean_sigmoid_xent(D_logits_fake, tf.zeros_like(D_output_fake))
d_loss = d_loss_real + d_loss_fake
# Generator: wants its samples classified as real (target 1).
g_loss = _mean_sigmoid_xent(D_logits_fake, tf.ones_like(D_output_fake))

# ---- Optimizers --------------------------------------------------------------
# Each optimizer updates only its own network's variables, selected by the
# variable-scope name used inside generator()/discriminator().
train_vars = tf.trainable_variables()
d_vars = _vars_with_prefix(train_vars, 'discriminator')
g_vars = _vars_with_prefix(train_vars, 'generator')
d_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(d_loss, var_list=d_vars)
g_train_opt = tf.train.AdamOptimizer(learning_rate).minimize(g_loss, var_list=g_vars)
# ---- Data --------------------------------------------------------------------
# NOTE: tensorflow.examples.tutorials was removed in TensorFlow 2.x; this script
# targets the TensorFlow 1.x API throughout.
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets("MNIST_data/")

# ---- Training ----------------------------------------------------------------
batch_size = 100
epochs = 100
samples = []  # one batch of 16 generated images per epoch, for inspection
batches_per_epoch = mnist.train.num_examples // batch_size

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for e in range(epochs):
        for _ in range(batches_per_epoch):
            batch = mnist.train.next_batch(batch_size)
            # Rescale pixels (presumably in [0, 1]) to [-1, 1] so real images
            # share the range of the generator's tanh output.
            batch_images = batch[0].reshape((batch_size, output_dim)) * 2 - 1
            batch_noise = np.random.uniform(-1, 1, size=(batch_size, input_dim))
            # One discriminator step, then one generator step on the same noise.
            sess.run(d_train_opt, feed_dict={X: batch_images, Z: batch_noise})
            sess.run(g_train_opt, feed_dict={Z: batch_noise})

        # Losses are reported on the last batch of the epoch only.
        train_loss_d = sess.run(d_loss, feed_dict={Z: batch_noise, X: batch_images})
        train_loss_g = sess.run(g_loss, feed_dict={Z: batch_noise})
        print("Epoch {}/{}...".format(e+1, epochs),
              "Discriminator Loss: {:.4f}...".format(train_loss_d),
              "Generator Loss: {:.4f}".format(train_loss_g))

        # Sample with the already-built generator output G (same variables).
        # Calling generator(..., reuse=True) here would add a new copy of the
        # generator ops to the graph on every epoch.
        sample_noise = np.random.uniform(-1, 1, size=(16, input_dim))
        samples.append(sess.run(G, feed_dict={Z: sample_noise}))

# ---- Show the images generated after the final epoch --------------------------
if samples:  # nothing to plot when epochs == 0
    fig, axes = plt.subplots(figsize=(7, 7), nrows=4, ncols=4, sharey=True, sharex=True)
    for img, ax in zip(samples[-1], axes.flatten()):
        ax.imshow(img.reshape((28, 28)), cmap='Greys_r')
        ax.xaxis.set_visible(False)
        ax.yaxis.set_visible(False)
    plt.show()
```
以上代码使用TensorFlow 1.x实现了一个简单的GAN模型，用于生成与MNIST数据集风格相似的手写数字图片。在训练过程中，判别器和生成器交替通过反向传播优化各自的参数；训练结束后，代码会显示生成器产生的一组（16张）手写数字图片。
阅读全文