堆叠自编码器（Stacked Autoencoder）的 Python 代码
时间: 2023-10-28 09:00:03 浏览: 105
下面是一个堆叠自编码器（Stacked Autoencoder）的简单示例代码，使用 Python 和 TensorFlow 库编写；其中解码器复用编码器权重矩阵的转置（即"权重绑定"）：
```python
import tensorflow as tf
import numpy as np
class StackedAutoencoder:
    """Stacked autoencoder whose decoder reuses the encoder weights transposed.

    ``layers`` gives the encoder widths, from the input size down to the code
    size (e.g. ``[784, 256, 64]``). ``decode`` walks the encoder layers in
    reverse order using each weight matrix transposed ("tied" weights), so it
    maps a tensor of width ``layers[-1]`` back to width ``layers[0]``.

    Attributes:
        layers: The width list passed to the constructor.
        weights: One ``(layers[i], layers[i + 1])`` variable per encoder layer.
        biases: The encoder biases (one ``(layers[i + 1],)`` vector per layer)
            followed by the decoder biases (one ``(layers[i],)`` vector per
            layer, in decoding order). Both live in this single list so that
            ``weights + biases`` is the complete set of trainable variables.
    """

    def __init__(self, layers):
        self.layers = layers
        self.weights = []
        self.biases = []

    def build(self):
        """Create (or re-create) all weight and bias variables.

        Weights use Glorot-uniform initialization and biases start at zero.
        Unit-variance normal weights (the previous choice) tend to push the
        sigmoid units into saturation when the input is wide, e.g. 784 pixels.
        """
        initializer = tf.keras.initializers.GlorotUniform()
        # Start from empty lists so a second build() call does not stack
        # duplicate layers onto the existing ones.
        self.weights = []
        encoder_biases = []
        decoder_biases = []
        for input_dim, output_dim in zip(self.layers[:-1], self.layers[1:]):
            self.weights.append(
                tf.Variable(initializer(shape=(input_dim, output_dim)))
            )
            encoder_biases.append(tf.Variable(tf.zeros((output_dim,))))
            # The transposed weight maps back to input_dim features, so the
            # matching decoder bias must have that width, not output_dim.
            decoder_biases.append(tf.Variable(tf.zeros((input_dim,))))
        # decode() visits the layers last-to-first, so store its biases that way.
        self.biases = encoder_biases + decoder_biases[::-1]

    def encode(self, x):
        """Map inputs of width ``layers[0]`` to outputs of width ``layers[-1]``.

        ``x`` may be any array-like; it is cast to float32 to match the
        variables (NumPy float64 data would otherwise make tf.matmul fail).
        """
        num_layers = len(self.weights)
        encoded = tf.cast(x, tf.float32)
        for weight, bias in zip(self.weights, self.biases[:num_layers]):
            encoded = tf.nn.sigmoid(tf.matmul(encoded, weight) + bias)
        return encoded

    def decode(self, encoded):
        """Map width-``layers[-1]`` codes back to width ``layers[0]``.

        Applies the encoder weights transposed, in reverse order, each followed
        by its own decoder bias and a sigmoid.
        """
        num_layers = len(self.weights)
        decoded = tf.cast(encoded, tf.float32)
        for weight, bias in zip(reversed(self.weights), self.biases[num_layers:]):
            decoded = tf.nn.sigmoid(
                tf.matmul(decoded, weight, transpose_b=True) + bias
            )
        return decoded
# Example usage
# Network structure: encoder widths from the 784-pixel input down to a
# 64-unit code. The decoder mirrors these layers with the transposed encoder
# weights, so the decoder widths must not be listed again here.
layers = [784, 256, 64]
# Build the autoencoder
autoencoder = StackedAutoencoder(layers)
autoencoder.build()
# Use MNIST as the example dataset (the labels are not needed)
(x_train, _), (x_test, _) = tf.keras.datasets.mnist.load_data()
# Scale pixels to [0, 1] as float32: the model variables are float32, and a
# plain `/ 255.0` on the raw integer arrays would produce float64 instead.
x_train = x_train.astype(np.float32) / 255.0
x_test = x_test.astype(np.float32) / 255.0
# Flatten each 28x28 image into a 784-element vector
x_train = x_train.reshape((-1, 784))
x_test = x_test.reshape((-1, 784))
# Training hyperparameters
epochs = 10
batch_size = 128
learning_rate = 0.001
# Optimizer and reconstruction loss
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate)
mse_loss = tf.keras.losses.MeanSquaredError()
# Train the autoencoder.
# Every model variable (weights and biases) is optimized; collect them once
# instead of rebuilding the list on every step.
trainable_vars = autoencoder.weights + autoencoder.biases
num_samples = x_train.shape[0]
for epoch in range(epochs):
    # Visit the samples in a fresh random order each epoch.
    order = np.random.permutation(num_samples)
    epoch_loss = 0.0
    # Step by batch_size so the final, smaller batch is trained on as well.
    for start in range(0, num_samples, batch_size):
        batch = x_train[order[start:start + batch_size]]
        # Forward pass
        with tf.GradientTape() as tape:
            encoded = autoencoder.encode(batch)
            decoded = autoencoder.decode(encoded)
            # Reconstruction error
            loss = mse_loss(batch, decoded)
        # Backward pass
        gradients = tape.gradient(loss, trainable_vars)
        optimizer.apply_gradients(zip(gradients, trainable_vars))
        # Weight by batch size so the reported value is a per-sample mean
        # over the whole epoch, not just the last batch's loss.
        epoch_loss += float(loss) * len(batch)
    print(f"Epoch {epoch+1}/{epochs}, Loss: {epoch_loss / num_samples}")
# Run the test set through the autoencoder.
encoded_images = autoencoder.encode(x_test)
decoded_images = autoencoder.decode(encoded_images)

# Plot the first n originals (top row) above their reconstructions (bottom row).
import matplotlib.pyplot as plt

n = 10  # number of test images to show
plt.figure(figsize=(20, 4))
for i in range(n):
    # (subplot index, image, caption) for this column, top panel first.
    panels = (
        (i + 1, x_test[i], "Original"),
        (i + 1 + n, decoded_images[i].numpy(), "Reconstructed"),
    )
    for position, image, caption in panels:
        ax = plt.subplot(2, n, position)
        plt.imshow(image.reshape(28, 28))
        plt.title(caption)
        plt.gray()
        # Hide tick marks and labels on both axes.
        for axis in (ax.get_xaxis(), ax.get_yaxis()):
            axis.set_visible(False)
plt.show()
```
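这段代码以 MNIST 数据集为例，演示了如何使用堆叠自编码器进行无监督学习。它包括以下步骤：定义网络结构、构建自编码器、训练自编码器，以及在测试集上重构图像。最后，它并排展示了前 10 张原始图像与对应的重构图像。需要注意的是，这里所有层以最小化重构误差为目标、端到端地联合训练，而不是采用经典堆叠自编码器中的逐层贪婪预训练；如果需要逐层预训练，可以先分别训练每一对编码/解码层，再把它们堆叠起来整体微调。你可以根据需要进行修改和扩展。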
阅读全文