自编码器pytorch
时间: 2023-07-06 08:32:44 浏览: 90
好的,很高兴回答你关于自编码器的问题。自编码器是一种无监督学习模型,它能够将输入数据压缩成一种称为编码的特征向量,然后通过解码器将编码还原为原始数据。以下是一个使用PyTorch实现自编码器的简单示例代码:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
# 定义自编码器模型
class Autoencoder(nn.Module):
def __init__(self):
super(Autoencoder, self).__init__()
self.encoder = nn.Sequential(
nn.Linear(784, 128),
nn.ReLU(True),
nn.Linear(128, 64),
nn.ReLU(True),
nn.Linear(64, 12),
nn.ReLU(True),
nn.Linear(12, 2))
self.decoder = nn.Sequential(
nn.Linear(2, 12),
nn.ReLU(True),
nn.Linear(12, 64),
nn.ReLU(True),
nn.Linear(64, 128),
nn.ReLU(True),
nn.Linear(128, 784),
nn.Tanh())
def forward(self, x):
x = self.encoder(x)
x = self.decoder(x)
return x
# 加载MNIST数据集
train_dataset = datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
# 定义训练数据加载器
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=128, shuffle=True)
# 实例化自编码器模型
model = Autoencoder()
# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)
# 训练自编码器模型
for epoch in range(20):
for data in train_loader:
img, _ = data
img = img.view(img.size(0), -1)
optimizer.zero_grad()
output = model(img)
loss = criterion(output, img)
loss.backward()
optimizer.step()
print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, 20, loss.item()))
# 使用自编码器模型进行图像重构
import matplotlib.pyplot as plt
images, _ = iter(train_loader).next()
images = images[0:5]
images = images.view(images.size(0), -1)
# 原始图像
plt.figure(figsize=(5, 2))
for i in range(5):
plt.subplot(2, 5, i+1)
plt.imshow(images[i].numpy().reshape(28, 28), cmap='gray')
plt.axis('off')
# 重构图像
with torch.no_grad():
encoded = model.encoder(images)
decoded = model.decoder(encoded)
decoded = decoded.view(decoded.size(0), 1, 28, 28)
plt.figure(figsize=(5, 2))
for i in range(5):
plt.subplot(2, 5, i+1)
plt.imshow(decoded[i].numpy().reshape(28, 28), cmap='gray')
plt.axis('off')
plt.show()
```
在这个示例中,我们使用PyTorch实现了一个简单的自编码器模型,并使用MNIST数据集进行训练和图像重构。通过这个示例,您可以更好地理解自编码器的工作原理和如何在PyTorch中实现它。
阅读全文