使用CIFAR10数据集实现椒盐噪声的图像去噪并显示图片,要完整代码
时间: 2023-06-13 11:07:18 浏览: 136
下面是使用PyTorch实现椒盐噪声的图像去噪并显示图片的完整代码:
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
import numpy as np
# 超参数
batch_size = 128
learning_rate = 0.001
num_epochs = 10
noise_ratio = 0.2
# 加载数据集
transform = transforms.Compose([transforms.ToTensor()])
train_data = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
test_data = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(dataset=train_data, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_data, batch_size=batch_size, shuffle=False)
# 定义模型
class DenoiseCNN(nn.Module):
def __init__(self):
super(DenoiseCNN, self).__init__()
self.encoder = nn.Sequential(
nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(16),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2),
nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1, bias=False),
nn.BatchNorm2d(32),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2)
)
self.decoder = nn.Sequential(
nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False),
nn.BatchNorm2d(16),
nn.ReLU(),
nn.ConvTranspose2d(16, 3, kernel_size=3, stride=2, padding=1, output_padding=1, bias=False),
nn.BatchNorm2d(3),
nn.Sigmoid()
)
def forward(self, x):
encoded = self.encoder(x)
decoded = self.decoder(encoded)
return decoded
model = DenoiseCNN()
# 定义损失函数和优化器
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
# 训练模型
for epoch in range(num_epochs):
for data in train_loader:
img, _ = data
# 添加椒盐噪声
noise = np.random.rand(*img.shape) < noise_ratio
img[noise] = 0.0 if np.random.rand() < 0.5 else 1.0
# 前向传播
output = model(img)
loss = criterion(output, img)
# 反向传播和优化
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 打印训练信息
print('Epoch [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, loss.item()))
# 测试模型
model.eval()
with torch.no_grad():
for data in test_loader:
img, _ = data
# 添加椒盐噪声
noise = np.random.rand(*img.shape) < noise_ratio
img[noise] = 0.0 if np.random.rand() < 0.5 else 1.0
# 前向传播
output = model(img)
# 显示图片
fig, ax = plt.subplots(1, 2)
ax[0].imshow(np.transpose(img[0], (1, 2, 0)))
ax[0].set_title('Original Image')
ax[1].imshow(np.transpose(output[0], (1, 2, 0)))
ax[1].set_title('Denoised Image')
plt.show()
break
```
在训练过程中,我们添加了椒盐噪声来模拟图像噪声。在测试过程中,我们同样添加了椒盐噪声,并将模型的输出与原始图像一起显示,以便进行比较。
阅读全文