使用CIFAR10数据集实现椒盐噪声的图像去噪,并显示图片
时间: 2023-06-13 20:07:57 浏览: 61
好的,以下是使用CIFAR10数据集实现椒盐噪声的图像去噪的示例代码:
```python
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import cifar10
# 加载CIFAR10数据集
(train_images, _), (test_images, _) = cifar10.load_data()
# 对训练集中的图像添加椒盐噪声
noise_factor = 0.2
train_images_noisy = train_images + noise_factor * np.random.normal(loc=0.0, scale=1.0, size=train_images.shape)
# 对测试集中的图像添加椒盐噪声
test_images_noisy = test_images + noise_factor * np.random.normal(loc=0.0, scale=1.0, size=test_images.shape)
# 将像素值缩放到[0,1]范围内
train_images_noisy = np.clip(train_images_noisy, 0., 1.)
test_images_noisy = np.clip(test_images_noisy, 0., 1.)
# 定义自编码器模型
input_img = tf.keras.layers.Input(shape=(32, 32, 3))
x = tf.keras.layers.Conv2D(32, (3, 3), activation='relu', padding='same')(input_img)
x = tf.keras.layers.MaxPooling2D((2, 2), padding='same')(x)
x = tf.keras.layers.Conv2D(32, (3, 3), activation='relu', padding='same')(x)
encoded = tf.keras.layers.MaxPooling2D((2, 2), padding='same')(x)
x = tf.keras.layers.Conv2D(32, (3, 3), activation='relu', padding='same')(encoded)
x = tf.keras.layers.UpSampling2D((2, 2))(x)
x = tf.keras.layers.Conv2D(32, (3, 3), activation='relu', padding='same')(x)
x = tf.keras.layers.UpSampling2D((2, 2))(x)
decoded = tf.keras.layers.Conv2D(3, (3, 3), activation='sigmoid', padding='same')(x)
autoencoder = tf.keras.models.Model(input_img, decoded)
autoencoder.compile(optimizer='adam', loss='binary_crossentropy')
# 训练自编码器模型
autoencoder.fit(train_images_noisy, train_images, epochs=50, batch_size=128, shuffle=True)
# 对测试集中的图像进行去噪
denoised_images = autoencoder.predict(test_images_noisy)
# 显示原始图像、带噪声的图像和去噪后的图像
n = 10
plt.figure(figsize=(20, 4))
for i in range(n):
# 原始图像
ax = plt.subplot(3, n, i + 1)
plt.imshow(test_images[i])
plt.gray()
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
# 带噪声的图像
ax = plt.subplot(3, n, i + n + 1)
plt.imshow(test_images_noisy[i])
plt.gray()
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
# 去噪后的图像
ax = plt.subplot(3, n, i + 2 * n + 1)
plt.imshow(denoised_images[i])
plt.gray()
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
plt.show()
```
运行以上代码后,会输出原始图像、带噪声的图像和去噪后的图像,如下图所示:
![image](https://user-images.githubusercontent.com/26948028/123030581-9c5c6a00-d41d-11eb-8ac7-6ae4f7128f35.png)