自编码器pytorch
时间: 2023-11-04 21:55:52 浏览: 88
gae-pytorch-master_pytorch_pytorchgae_GAE_自编码器_gaepytorchmaster_
5星 · 资源好评率100%
自编码器是一种无监督学习算法,可以用于数据压缩、特征提取、数据去噪等任务。PyTorch是一个流行的深度学习框架,提供了丰富的工具和库,可以方便地实现自编码器。
下面是一个简单的PyTorch自编码器实现示例:
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
# 定义自编码器类
class AutoEncoder(nn.Module):
def __init__(self):
super(AutoEncoder, self).__init__()
self.encoder = nn.Sequential(
nn.Linear(784, 128),
nn.ReLU(),
nn.Linear(128, 64),
nn.ReLU(),
nn.Linear(64, 12),
nn.ReLU(),
nn.Linear(12, 3)
)
self.decoder = nn.Sequential(
nn.Linear(3, 12),
nn.ReLU(),
nn.Linear(12, 64),
nn.ReLU(),
nn.Linear(64, 128),
nn.ReLU(),
nn.Linear(128, 784),
nn.Sigmoid()
)
def forward(self, x):
x = self.encoder(x)
x = self.decoder(x)
return x
# 加载数据集
train_data = torch.utils.data.DataLoader(
torchvision.datasets.MNIST(
root='./mnist/',
train=True,
transform=torchvision.transforms.ToTensor(),
download=True
),
batch_size=64,
shuffle=True
)
# 实例化自编码器
autoencoder = AutoEncoder()
# 定义优化器和损失函数
optimizer = optim.Adam(autoencoder.parameters(), lr=0.001)
loss_func = nn.MSELoss()
# 训练自编码器
for epoch in range(10):
for step, (x, _) in enumerate(train_data):
b_x = x.view(-1, 784)
b_y = x.view(-1, 784)
encoded = autoencoder.encoder(b_x)
decoded = autoencoder.decoder(encoded)
loss = loss_func(decoded, b_y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if step % 100 == 0:
print('Epoch:', epoch, '| train loss: %.4f' % loss.data.numpy())
```
这个自编码器使用MNIST数据集进行训练,包含一个由3个全连接层组成的编码器和一个由3个全连接层组成的解码器。训练过程中使用MSELoss作为损失函数,优化器使用Adam算法进行优化。
训练完成后,可以使用自编码器对数据进行重建和压缩。例如,可以通过以下代码使用自编码器对一张MNIST图片进行压缩和重建:
```python
import matplotlib.pyplot as plt
import numpy as np
# 加载一张MNIST图片
test_data = torchvision.datasets.MNIST(
root='./mnist/',
train=False,
transform=torchvision.transforms.ToTensor(),
download=True
)
test_x = test_data.data.type(torch.FloatTensor)
test_x = test_x.view(-1, 784)
# 使用自编码器进行压缩和重建
encoded = autoencoder.encoder(test_x)
decoded = autoencoder.decoder(encoded)
# 可视化原始图片和重建图片
n = 5
plt.figure(figsize=(10, 4))
for i in range(n):
ax = plt.subplot(2, n, i + 1)
plt.imshow(test_x[i].numpy().reshape(28, 28))
plt.gray()
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
ax = plt.subplot(2, n, i + 1 + n)
plt.imshow(decoded[i].data.numpy().reshape(28, 28))
plt.gray()
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
plt.show()
```
这段代码会显示原始图片和自编码器重建的图片,可以用于观察自编码器的效果。
阅读全文