去噪自编码器（PyTorch 实现）
时间: 2023-05-28 15:06:05 浏览: 95
以下是一个简单的去噪自编码器的 PyTorch 实现:
```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor, Normalize
from torchvision.utils import save_image
class DenoisingAutoencoder(nn.Module):
    """Fully connected denoising autoencoder.

    The encoder flattens each input sample and passes it through a stack of
    Linear+ReLU layers down to a latent vector. The decoder mirrors the
    encoder back to the input size and ends with a Sigmoid, so
    reconstructions lie in [0, 1] (the pixel range produced by ``ToTensor()``).

    Args:
        input_shape: Per-sample shape, e.g. ``(C, H, W)``. The default
            ``(1, 28, 28)`` is one MNIST image. ``forward`` reshapes its
            output back to ``(batch, *input_shape)``.
        hidden_dims: Encoder layer widths, the last entry being the latent
            size; the decoder uses them in reverse order. The default
            ``(128, 64, 32)`` gives the 784-128-64-32 architecture, with the
            same submodule indices (``encoder.0``, ``encoder.2``, ...).

    Raises:
        ValueError: if ``hidden_dims`` is empty.
    """

    def __init__(self, *, input_shape=(1, 28, 28), hidden_dims=(128, 64, 32)):
        super().__init__()
        if not hidden_dims:
            raise ValueError("hidden_dims must contain at least one layer width")
        self.input_shape = tuple(input_shape)
        # Number of values in one flattened sample (784 for 1x28x28).
        self.in_features = torch.Size(self.input_shape).numel()

        # Encoder: in_features -> hidden_dims[0] -> ... -> hidden_dims[-1],
        # each Linear followed by ReLU (so latent codes are non-negative).
        enc_dims = [self.in_features, *hidden_dims]
        enc_layers = []
        for d_in, d_out in zip(enc_dims, enc_dims[1:]):
            enc_layers += [nn.Linear(d_in, d_out), nn.ReLU()]
        self.encoder = nn.Sequential(*enc_layers)

        # Decoder mirrors the encoder. Its last activation is Sigmoid rather
        # than ReLU, which keeps outputs in the [0, 1] pixel range.
        dec_dims = [*reversed(hidden_dims), self.in_features]
        dec_layers = []
        for d_in, d_out in zip(dec_dims, dec_dims[1:]):
            dec_layers += [nn.Linear(d_in, d_out), nn.ReLU()]
        dec_layers[-1] = nn.Sigmoid()
        self.decoder = nn.Sequential(*dec_layers)

    def forward(self, x):
        """Reconstruct ``x`` and return a tensor of shape ``(batch, *input_shape)``.

        ``x`` may be batched images or already-flattened vectors, as long as
        its total size is a multiple of ``in_features``.
        """
        # reshape (not view) so non-contiguous inputs, e.g. transposed or
        # sliced tensors, are accepted instead of raising.
        flat = x.reshape(-1, self.in_features)
        recon = self.decoder(self.encoder(flat))
        return recon.reshape(-1, *self.input_shape)
def train(model, dataloader, optimizer, criterion, device, *, noise_std=0.1):
    """Run one training epoch of the denoising objective.

    Each clean batch is corrupted with additive Gaussian noise
    (``noise_std`` times a standard normal draw). The model reconstructs
    from the noisy input, and ``criterion(reconstruction, clean)`` is
    minimized.

    Args:
        model: Module mapping a noisy batch to a reconstruction that
            ``criterion`` can compare with the clean batch.
        dataloader: Iterable of ``(data, label)`` pairs; labels are ignored.
            It does not need to support ``len()``.
        optimizer: Optimizer over ``model``'s parameters.
        criterion: Loss called as ``criterion(output, target)`` that returns
            a scalar tensor.
        device: Device each batch is moved to.
        noise_std: Standard deviation of the additive noise (default 0.1).

    Returns:
        The mean of the per-batch loss values over the epoch.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model.train()
    total_loss = 0.0
    # Count batches as they arrive rather than calling len(dataloader), so
    # iterables without __len__ work and an empty loader is detected.
    num_batches = 0
    for data, _ in dataloader:
        data = data.to(device)
        noisy_data = data + torch.randn_like(data) * noise_std  # additive Gaussian noise
        optimizer.zero_grad()
        recon_data = model(noisy_data)
        # Target is the CLEAN input: the model learns to undo the corruption.
        loss = criterion(recon_data, data)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    if num_batches == 0:
        raise ValueError("dataloader yielded no batches")
    return total_loss / num_batches
def test(model, dataloader, criterion, device):
model.eval()
test_loss = 0
with torch.no_grad():
for data, _ in dataloader:
data = data.to(device)
noisy_data = data + torch.randn_like(data) * 0.1 # 添加高斯噪声
recon_data = model(noisy_data)
loss = criterion(recon_data, data)
test_loss += loss.item()
return test_loss / len(dataloader)
if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print('Using device:', device)

    # ToTensor() yields pixels in [0, 1], the same range as the decoder's Sigmoid output.
    train_dataset = MNIST(root='./data', train=True, transform=ToTensor(), download=True)
    test_dataset = MNIST(root='./data', train=False, transform=ToTensor(), download=True)
    train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=4)
    test_dataloader = DataLoader(test_dataset, batch_size=128, shuffle=False, num_workers=4)

    model = DenoisingAutoencoder().to(device)
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    criterion = nn.MSELoss()

    # A fixed set of test digits with one fixed noise draw, so the per-epoch
    # denoising snapshots can be compared across epochs.
    noise_std = 0.1  # same corruption level that train()/test() apply
    num_vis = 8
    vis_clean = torch.stack([test_dataset[i][0] for i in range(num_vis)]).to(device)
    vis_noisy = vis_clean + torch.randn_like(vis_clean) * noise_std

    for epoch in range(1, 21):
        train_loss = train(model, train_dataloader, optimizer, criterion, device)
        test_loss = test(model, test_dataloader, criterion, device)
        print('Epoch:', epoch, 'Train Loss:', train_loss, 'Test Loss:', test_loss)

        model.eval()
        with torch.no_grad():
            # Denoising snapshot. Grid rows are clean / noisy / reconstructed.
            # save_image clamps to [0, 1], so noisy pixels outside that range display fine.
            recon = model(vis_noisy)
            grid = torch.cat([vis_clean, vis_noisy, recon]).cpu()
            save_image(grid, f'denoise_{epoch}.png', nrow=num_vis)

            # Decode random 32-dim latent vectors, as the original script did.
            # NOTE: this is not true generative sampling. The latent space is
            # unregularized, and the encoder ends in ReLU, so real codes are
            # non-negative while randn draws are not. Expect unrealistic images.
            sample = torch.randn(64, 32).to(device)
            sample = model.decoder(sample).cpu()
            save_image(sample.view(64, 1, 28, 28), f'sample_{epoch}.png')
```
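该模型由一个三层全连接编码器和一个三层全连接解码器组成：编码器将 28×28 的输入展平为 784 维，依次映射到 128 维和 64 维，最终输出 32 维的特征向量；解码器以该特征向量为输入，对称地还原为 784 维，并通过 Sigmoid 输出 [0, 1] 范围内的像素值。训练时，每个输入样本都会加上标准差为 0.1 的高斯噪声，模型以带噪图像为输入、以原始干净图像为目标，用 MSE 损失计算重构误差。每个 epoch 结束时，程序会生成 64 个 32 维的标准正态随机向量，送入解码器生成样本图片并保存。需要注意的是，普通（去噪）自编码器的隐空间没有经过正则化，而且编码器最后一层是 ReLU、输出非负，因此用标准正态随机向量解码得到的图片通常不是有意义的数字，只能作为粗略的可视化，并不等同于真正的生成模型（如 VAE）的采样。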
阅读全文