自编码器代码pytorch
时间: 2023-10-22 14:06:09 浏览: 55
关于自编码器的PyTorch代码,你可以参考以下示例代码:
```python
import torch
import torch.nn as nn
# 定义自编码器模型
class Autoencoder(nn.Module):
def __init__(self, input_dim, encoding_dim):
super(Autoencoder, self).__init__()
self.encoder = nn.Sequential(
nn.Linear(input_dim, encoding_dim),
nn.ReLU(True)
)
self.decoder = nn.Sequential(
nn.Linear(encoding_dim, input_dim),
nn.Sigmoid()
)
def forward(self, x):
encoded = self.encoder(x)
decoded = self.decoder(encoded)
return decoded
# 创建自编码器实例
input_dim = 784 # 输入数据维度
encoding_dim = 32 # 编码器维度
autoencoder = Autoencoder(input_dim, encoding_dim)
# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(autoencoder.parameters(), lr=0.01)
# 训练自编码器
num_epochs = 10
for epoch in range(num_epochs):
for data in dataloader:
inputs, _ = data
optimizer.zero_grad()
outputs = autoencoder(inputs)
loss = criterion(outputs, inputs)
loss.backward()
optimizer.step()
# 使用训练好的自编码器进行重构
reconstructed = autoencoder(inputs)
# 打印重构结果
print(reconstructed)
```