用pytorch 写一个降噪自动编码器
时间: 2024-05-13 10:16:49 浏览: 130
使用pytorch搭建自编码器,实现图像的去噪
好的,下面是一个简单的降噪自动编码器的 PyTorch 实现代码:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
class DenoisingAutoencoder(nn.Module):
def __init__(self):
super(DenoisingAutoencoder, self).__init__()
self.encoder = nn.Sequential(
nn.Linear(784, 128),
nn.ReLU(),
nn.Linear(128, 64),
nn.ReLU(),
nn.Linear(64, 12),
nn.ReLU(),
nn.Linear(12, 3))
self.decoder = nn.Sequential(
nn.Linear(3, 12),
nn.ReLU(),
nn.Linear(12, 64),
nn.ReLU(),
nn.Linear(64, 128),
nn.ReLU(),
nn.Linear(128, 784),
nn.Sigmoid())
def forward(self, x):
x = self.encoder(x)
x = self.decoder(x)
return x
def add_noise(data, noise_factor=0.5):
noisy_data = data + noise_factor * torch.randn(data.shape)
return np.clip(noisy_data, 0., 1.)
def train(model, train_loader, optimizer, criterion, epochs):
for epoch in range(epochs):
for data in train_loader:
img, _ = data
img = img.view(img.size(0), -1)
noisy_img = torch.tensor(add_noise(img))
optimizer.zero_grad()
output = model(noisy_img)
loss = criterion(output, img)
loss.backward()
optimizer.step()
print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch + 1, epochs, loss.item()))
if __name__ == '__main__':
# 加载 MNIST 数据集
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transforms.ToTensor(), download=True)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=128, shuffle=True)
# 初始化模型、优化器、损失函数
model = DenoisingAutoencoder()
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.MSELoss()
# 训练模型
train(model, train_loader, optimizer, criterion, epochs=10)
```
这个代码实现了一个简单的降噪自动编码器,使用的是 MNIST 手写数字数据集。在训练过程中,每个输入图像都会被添加一定程度的高斯噪声,然后输入到自动编码器中进行训练。训练后的自动编码器可以用于去除图像中的噪声。
阅读全文