基于GAN的图像压缩代码
时间: 2023-11-14 20:07:19 浏览: 35
以下是基于GAN的图像压缩代码的示例:
```python
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Reshape, Flatten, Conv2D, Conv2DTranspose
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.datasets import mnist
import numpy as np
# Geometry of one MNIST image (rows, cols, channels) and the length of the
# random noise vector that the generator consumes.
img_rows, img_cols, channels = 28, 28, 1
img_shape = (img_rows, img_cols, channels)
latent_dim = 100
# Build the generator.
def build_generator():
    """Build the generator network mapping a noise vector to an image.

    Shape path (``padding="same"``): Dense -> reshape to 7x7x128
    -> stride-2 transposed conv 14x14x128 -> stride-1 14x14x64
    -> stride-2 28x28x``channels`` with a ``tanh`` output in [-1, 1].

    Returns:
        An uncompiled ``Model`` taking noise of shape ``(batch, latent_dim)``.
    """
    noise = Input(shape=(latent_dim,))
    x = Dense(128 * 7 * 7, activation="relu")(noise)
    x = Reshape((7, 7, 128))(x)
    # Without an activation, consecutive transposed convolutions compose into
    # a single linear map; ReLU keeps these hidden layers non-linear.
    x = Conv2DTranspose(128, kernel_size=3, strides=2, padding="same", activation="relu")(x)
    x = Conv2DTranspose(64, kernel_size=3, strides=1, padding="same", activation="relu")(x)
    # tanh matches the [-1, 1] scaling applied to the real images in train().
    x = Conv2DTranspose(channels, kernel_size=3, strides=2, padding="same", activation="tanh")(x)
    generator = Model(noise, x)
    return generator
# Build the discriminator.
def build_discriminator():
    """Build and compile the discriminator: image -> probability it is real.

    Three stride-2 convolutions (``padding="same"``) downsample
    28x28 -> 14x14 -> 7x7 -> 4x4, followed by one sigmoid unit.

    Returns:
        A ``Model`` compiled with binary cross-entropy and an ``accuracy``
        metric, so ``train_on_batch`` returns ``[loss, accuracy]``.
    """
    img = Input(shape=img_shape)
    # Leaky ReLU (tf.nn.leaky_relu, default slope 0.2) between convolutions:
    # without any activation the network is linear up to the final sigmoid,
    # i.e. it degenerates to logistic regression on the pixels.
    x = Conv2D(32, kernel_size=3, strides=2, padding="same", activation=tf.nn.leaky_relu)(img)
    x = Conv2D(64, kernel_size=3, strides=2, padding="same", activation=tf.nn.leaky_relu)(x)
    x = Conv2D(128, kernel_size=3, strides=2, padding="same", activation=tf.nn.leaky_relu)(x)
    x = Flatten()(x)
    x = Dense(1, activation="sigmoid")(x)
    discriminator = Model(img, x)
    discriminator.compile(loss="binary_crossentropy", optimizer=Adam(), metrics=["accuracy"])
    return discriminator
# Build the combined (generator -> discriminator) model.
def build_gan(generator, discriminator):
    """Chain *generator* into a frozen *discriminator* for generator training.

    ``discriminator.trainable`` is set to ``False`` before the combined model
    is compiled, so training the returned model updates the generator's
    weights only.

    Args:
        generator: Model mapping ``(batch, latent_dim)`` noise to images.
        discriminator: Model mapping images to a real/fake probability.

    Returns:
        The compiled combined ``Model``: noise -> discriminator score.
    """
    # Freeze D inside the combined graph; only G learns through it.
    discriminator.trainable = False
    z = Input(shape=(latent_dim,))
    combined = Model(z, discriminator(generator(z)))
    combined.compile(optimizer=Adam(), loss="binary_crossentropy")
    return combined
# Train the GAN.
def train(epochs, batch_size=128, save_interval=50):
    """Train the GAN on MNIST, drawing one random minibatch per "epoch".

    Each loop iteration samples a single random batch, so ``epochs`` is
    really the number of training steps, not full passes over the dataset.

    Args:
        epochs: Number of training steps. Each step does one discriminator
            update on real images, one on generated images, and one
            generator update.
        batch_size: Number of real (and of generated) images per step.
        save_interval: Every ``save_interval`` steps, starting at step 0,
            print the losses and save a sample grid via ``save_images``.
    """
    # Load the training images only; labels are not needed.
    (X_train, _), (_, _) = mnist.load_data()
    # Scale pixels from [0, 255] to [-1, 1] to match the generator's tanh
    # output. Casting to float32 first halves memory versus the float64 that
    # plain division of the integer array would produce.
    X_train = X_train.astype(np.float32) / 127.5 - 1.0
    # Add a trailing channel axis: (N, rows, cols) -> (N, rows, cols, 1).
    X_train = np.expand_dims(X_train, axis=3)

    generator = build_generator()
    discriminator = build_discriminator()
    gan = build_gan(generator, discriminator)

    # The target labels are identical at every step; build them once.
    valid = np.ones((batch_size, 1))
    fake = np.zeros((batch_size, 1))

    for epoch in range(epochs):
        # --- Discriminator: one real batch and one generated batch ---
        idx = np.random.randint(0, X_train.shape[0], batch_size)
        imgs = X_train[idx]
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        # verbose=0: don't print a progress bar on every training step.
        gen_imgs = generator.predict(noise, verbose=0)
        # NOTE(review): relies on tf.keras 2.x semantics, where setting
        # discriminator.trainable = False after it was compiled (in build_gan)
        # does not freeze it for its own train_on_batch; confirm on Keras 3.
        d_loss_real = discriminator.train_on_batch(imgs, valid)
        d_loss_fake = discriminator.train_on_batch(gen_imgs, fake)
        # Average [loss, accuracy] over the real and fake halves.
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

        # --- Generator: label fakes as real so G learns to fool D ---
        noise = np.random.normal(0, 1, (batch_size, latent_dim))
        g_loss = gan.train_on_batch(noise, valid)

        # Report and save samples every save_interval steps (not every step).
        if epoch % save_interval == 0:
            print("Epoch %d [D loss: %f, acc.: %.2f%%] [G loss: %f]" % (epoch, d_loss[0], 100 * d_loss[1], g_loss))
            save_images(epoch, generator)
# Generate sample images and save them.
def save_images(epoch, generator):
    """Sample a 5x5 grid from *generator* and save it as ``images/mnist_<epoch>.png``.

    Args:
        epoch: Training step number; used only in the output filename.
        generator: Model mapping ``(n, latent_dim)`` noise to images, expected
            to output values in [-1, 1] (as the tanh output of build_generator).
    """
    # matplotlib is used below but was never imported in this file, which
    # made the first call fail with NameError.
    import os
    import matplotlib.pyplot as plt

    r, c = 5, 5
    noise = np.random.normal(0, 1, (r * c, latent_dim))
    gen_imgs = generator.predict(noise, verbose=0)
    # Map [-1, 1] back to [0, 1] for display.
    gen_imgs = 0.5 * gen_imgs + 0.5

    fig, axs = plt.subplots(r, c, figsize=(10, 10))
    # axs.flat walks the grid row by row, pairing each axis with one sample.
    for ax, img in zip(axs.flat, gen_imgs):
        ax.imshow(img[:, :, 0], cmap="gray")
        ax.axis("off")
    # savefig does not create missing directories.
    os.makedirs("images", exist_ok=True)
    fig.savefig("images/mnist_%d.png" % epoch)
    # Close this figure explicitly so repeated calls don't leak open figures.
    plt.close(fig)
# Train the model only when this file is run as a script, not when imported.
if __name__ == "__main__":
    train(epochs=20000, batch_size=32, save_interval=100)
```
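上述代码使用MNIST数据集训练了一个基础的DCGAN,并每隔`save_interval`步保存一组由随机噪声生成的图像样本。需要注意的是,这段代码本身并不会对输入图像进行压缩:生成器只是把`latent_dim`维的随机噪声映射成图像,代码中没有把真实图像编码成紧凑表示的编码器。若要实现基于GAN的图像压缩,通常需要额外加入一个编码器(并对其输出进行量化),先将图像映射为潜在向量,再由生成器根据该向量重建图像;此时潜在表示的大小决定了压缩比例。你可以根据自己的需求修改代码以适应不同的数据集和压缩比例。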