Please give me one example each of an autoencoder (AE) and a denoising autoencoder (DAE), implemented in PyTorch.
Example of an autoencoder (AE), which compresses the input into a lower-dimensional code and reconstructs the input from it:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
# Define the autoencoder: a single-hidden-layer encoder/decoder pair
class AE(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(AE, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.Linear(hidden_size, input_size),
            nn.Sigmoid(),  # outputs in [0, 1], matching normalized pixel values
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x
# Define the training loop: reconstruct each batch and minimize MSE
def train(model, train_loader, epochs, lr):
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    for epoch in range(epochs):
        running_loss = 0.0
        for data in train_loader:
            inputs, _ = data  # labels are not needed
            inputs = inputs.view(inputs.size(0), -1)  # flatten 28x28 images to 784
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, inputs)  # reconstruction loss against the input itself
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print('Epoch [%d], loss: %.4f' % (epoch + 1, running_loss / len(train_loader)))
# Load the MNIST dataset
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
train_dataset = MNIST(root='data/', train=True, transform=ToTensor(), download=True)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
# Train the autoencoder (784 input pixels, 256 hidden units)
ae = AE(784, 256)
train(ae, train_loader, 10, 0.001)
# Test the autoencoder: reconstruct one test batch and visualize the result
import matplotlib.pyplot as plt
test_dataset = MNIST(root='data/', train=False, transform=ToTensor(), download=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=True)
with torch.no_grad():
    for data in test_loader:
        inputs, _ = data
        inputs = inputs.view(inputs.size(0), -1)
        outputs = ae(inputs)
        break  # one batch is enough for visualization
inputs = inputs.numpy()
outputs = outputs.numpy()
fig, ax = plt.subplots(1, 2, figsize=(8, 4))
ax[0].imshow(np.reshape(inputs[0], (28, 28)), cmap='gray')
ax[0].set_title('Input')
ax[1].imshow(np.reshape(outputs[0], (28, 28)), cmap='gray')
ax[1].set_title('Output')
plt.show()
```
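After training, the encoder half can also be used on its own as a feature extractor. A minimal sketch (reusing the `ae` model and `test_loader` defined above; not part of the original example):

```python
# Extract the 256-dimensional latent codes for one test batch.
with torch.no_grad():
    batch, _ = next(iter(test_loader))
    batch = batch.view(batch.size(0), -1)  # flatten to [batch, 784]
    codes = ae.encoder(batch)
print(codes.shape)  # torch.Size([32, 256])
```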
Example of a denoising autoencoder (DAE), which shares the AE architecture but is trained to reconstruct the clean input from a noise-corrupted version:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
# Define the denoising autoencoder: same architecture as the AE;
# the difference lies entirely in the training procedure (noisy inputs, clean targets)
class DAE(nn.Module):
    def __init__(self, input_size, hidden_size):
        super(DAE, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(input_size, hidden_size),
            nn.ReLU(),
        )
        self.decoder = nn.Sequential(
            nn.Linear(hidden_size, input_size),
            nn.Sigmoid(),
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x
# Define the training loop: corrupt each batch with Gaussian noise,
# then minimize MSE against the original (clean) inputs
def train(model, train_loader, epochs, lr):
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)
    for epoch in range(epochs):
        running_loss = 0.0
        for data in train_loader:
            inputs, _ = data
            inputs = inputs.view(inputs.size(0), -1)
            # Corrupt the inputs with additive Gaussian noise
            inputs_noisy = inputs + torch.randn_like(inputs) * 0.2
            inputs_noisy = torch.clamp(inputs_noisy, 0., 1.)  # keep pixels in [0, 1]
            optimizer.zero_grad()
            outputs = model(inputs_noisy)
            loss = criterion(outputs, inputs)  # target is the clean input
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
        print('Epoch [%d], loss: %.4f' % (epoch + 1, running_loss / len(train_loader)))
# Load the MNIST dataset
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
train_dataset = MNIST(root='data/', train=True, transform=ToTensor(), download=True)
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
# Train the denoising autoencoder
dae = DAE(784, 256)
train(dae, train_loader, 10, 0.001)
# Test the denoising autoencoder: denoise one test batch and visualize the result
import matplotlib.pyplot as plt
test_dataset = MNIST(root='data/', train=False, transform=ToTensor(), download=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=True)
with torch.no_grad():
    for data in test_loader:
        inputs, _ = data
        inputs = inputs.view(inputs.size(0), -1)
        # Corrupt the test batch the same way as during training
        inputs_noisy = inputs + torch.randn_like(inputs) * 0.2
        inputs_noisy = torch.clamp(inputs_noisy, 0., 1.)
        outputs = dae(inputs_noisy)
        break  # one batch is enough for visualization
inputs = inputs.numpy()
inputs_noisy = inputs_noisy.numpy()
outputs = outputs.numpy()
fig, ax = plt.subplots(1, 3, figsize=(12, 4))
ax[0].imshow(np.reshape(inputs[0], (28, 28)), cmap='gray')
ax[0].set_title('Input')
ax[1].imshow(np.reshape(inputs_noisy[0], (28, 28)), cmap='gray')
ax[1].set_title('Input with noise')
ax[2].imshow(np.reshape(outputs[0], (28, 28)), cmap='gray')
ax[2].set_title('Output')
plt.show()
```
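Beyond the visual check, denoising quality can be quantified with the average per-pixel reconstruction error over the whole test set. A minimal sketch (reusing `dae` and `test_loader` from above; the 0.2 noise level matches training):

```python
# Average per-pixel MSE between denoised reconstructions and clean images.
total_err, total_pix = 0.0, 0
with torch.no_grad():
    for batch, _ in test_loader:
        batch = batch.view(batch.size(0), -1)
        noisy = torch.clamp(batch + torch.randn_like(batch) * 0.2, 0., 1.)
        recon = dae(noisy)
        total_err += nn.functional.mse_loss(recon, batch, reduction='sum').item()
        total_pix += batch.numel()
print('Test MSE per pixel: %.6f' % (total_err / total_pix))
```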