使用PORTCH采用CIFAR10数据集基于卷积神经网络的图像去噪 椒盐噪声
时间: 2024-01-25 13:06:00 浏览: 131
【图像去噪】基于快速有效多级选择性卷积滤波器去除椒盐噪声附matlab代码 上传.zip
图像去噪是计算机视觉领域的一个重要研究方向,其目的是通过一些算法将含有噪声的图像变为清晰的图像。在实际应用中,图像常常受到各种类型的噪声干扰,其中椒盐噪声是最常见的一种。本文将介绍如何使用PyTorch框架和CIFAR10数据集基于卷积神经网络实现椒盐噪声图像的去噪。
## 1. 数据集的准备
CIFAR10数据集是一个经典的图像分类数据集,包含60000张32x32的彩色图像,共分为10个类别,每个类别有6000张图像。在这里,我们将使用CIFAR10数据集中的一部分图像,通过添加椒盐噪声来生成训练集和测试集。
```python
import torch
import torchvision
from torchvision import transforms
import numpy as np
import random
from PIL import Image
# 加载CIFAR10数据集
train_data = torchvision.datasets.CIFAR10(root='./data', train=True, download=True)
# 定义添加椒盐噪声的函数
def add_noise(img, noise_type='s&p', SNR=0.1, prob=0.5):
"""
img: PIL.Image,输入的图像
noise_type: str,噪声类型,可选的有:'gaussian', 'poisson', 's&p',默认为's&p'
SNR: float,信噪比,取值范围为[0, 1],默认为0.1
prob: float,噪声添加的概率,取值范围为[0, 1],默认为0.5
"""
img = np.array(img)
h, w, c = img.shape
# 生成噪声
if noise_type == 'gaussian':
noise = np.random.normal(0, 1, (h, w, c)) * 255 * (1 - SNR)
elif noise_type == 'poisson':
noise = np.random.poisson(255 * (1 - SNR), (h, w, c)) / (255 * (1 - SNR))
elif noise_type == 's&p':
noise = np.zeros((h, w, c))
# 添加椒盐噪声
for i in range(h):
for j in range(w):
rand = random.random()
if rand < prob:
noise[i, j, :] = 0
elif rand > 1 - prob:
noise[i, j, :] = 255
else:
noise[i, j, :] = img[i, j, :]
# 将图像和噪声相加
img_noise = img + noise
img_noise = np.clip(img_noise, 0, 255).astype(np.uint8)
img_noise = Image.fromarray(img_noise)
return img_noise
# 定义训练集和测试集
train_transform = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor()
])
test_transform = transforms.Compose([
transforms.ToTensor()
])
train_set = torchvision.datasets.CIFAR10(root='./data', train=True, transform=train_transform, download=True)
test_set = torchvision.datasets.CIFAR10(root='./data', train=False, transform=test_transform, download=True)
# 添加椒盐噪声
train_noisy_set = []
test_noisy_set = []
for i in range(len(train_set)):
x, y = train_set[i]
x_noisy = add_noise(x)
train_noisy_set.append((x_noisy, y))
for i in range(len(test_set)):
x, y = test_set[i]
x_noisy = add_noise(x)
test_noisy_set.append((x_noisy, y))
# 将数据集转换为DataLoader格式
train_loader = torch.utils.data.DataLoader(train_noisy_set, batch_size=128, shuffle=True, num_workers=4)
test_loader = torch.utils.data.DataLoader(test_noisy_set, batch_size=128, shuffle=False, num_workers=4)
```
## 2. 模型的搭建
在本文中,我们将使用一个简单的卷积神经网络对图像进行去噪。该网络包含多个卷积层和池化层,最后通过全连接层输出去噪后的图像。
```python
import torch.nn as nn
class DenoiseNet(nn.Module):
def __init__(self):
super(DenoiseNet, self).__init__()
self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
self.bn1 = nn.BatchNorm2d(32)
self.relu1 = nn.ReLU(inplace=True)
self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
self.bn2 = nn.BatchNorm2d(64)
self.relu2 = nn.ReLU(inplace=True)
self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
self.bn3 = nn.BatchNorm2d(128)
self.relu3 = nn.ReLU(inplace=True)
self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)
self.conv4 = nn.Conv2d(128, 256, kernel_size=3, padding=1)
self.bn4 = nn.BatchNorm2d(256)
self.relu4 = nn.ReLU(inplace=True)
self.deconv1 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1)
self.bn5 = nn.BatchNorm2d(128)
self.relu5 = nn.ReLU(inplace=True)
self.deconv2 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1)
self.bn6 = nn.BatchNorm2d(64)
self.relu6 = nn.ReLU(inplace=True)
self.deconv3 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1)
self.bn7 = nn.BatchNorm2d(32)
self.relu7 = nn.ReLU(inplace=True)
self.conv5 = nn.Conv2d(32, 3, kernel_size=3, padding=1)
def forward(self, x):
x = self.conv1(x)
x = self.bn1(x)
x = self.relu1(x)
x = self.pool1(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu2(x)
x = self.pool2(x)
x = self.conv3(x)
x = self.bn3(x)
x = self.relu3(x)
x = self.pool3(x)
x = self.conv4(x)
x = self.bn4(x)
x = self.relu4(x)
x = self.deconv1(x)
x = self.bn5(x)
x = self.relu5(x)
x = self.deconv2(x)
x = self.bn6(x)
x = self.relu6(x)
x = self.deconv3(x)
x = self.bn7(x)
x = self.relu7(x)
x = self.conv5(x)
return x
```
## 3. 模型的训练与测试
我们将使用均方误差(MSE)作为损失函数,使用Adam优化器进行参数优化。在每个epoch结束后,我们将会对模型进行一次测试,计算测试集上的损失和准确率。
```python
import torch.optim as optim
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = DenoiseNet().to(device)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
def train(model, dataloader, criterion, optimizer):
model.train()
running_loss = 0.0
for i, data in enumerate(dataloader):
inputs, _ = data
inputs = inputs.to(device)
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, inputs)
loss.backward()
optimizer.step()
running_loss += loss.item() * inputs.size(0)
return running_loss / len(dataloader.dataset)
def test(model, dataloader, criterion):
model.eval()
running_loss = 0.0
with torch.no_grad():
for i, data in enumerate(dataloader):
inputs, _ = data
inputs = inputs.to(device)
outputs = model(inputs)
loss = criterion(outputs, inputs)
running_loss += loss.item() * inputs.size(0)
return running_loss / len(dataloader.dataset)
for epoch in range(20):
train_loss = train(model, train_loader, criterion, optimizer)
test_loss = test(model, test_loader, criterion)
print(f'Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}')
```
## 4. 结果的可视化
最后,我们将使用matplotlib库将原始图像、含有椒盐噪声的图像和去噪后的图像进行可视化展示。
```python
import matplotlib.pyplot as plt
# 选择一张测试集中的图像进行展示
index = 0
original, _ = test_set[index]
noisy, _ = test_noisy_set[index]
clean = model(torch.unsqueeze(noisy, 0).to(device)).detach().cpu()
# 将图像转换为PIL.Image格式
original = transforms.functional.to_pil_image(original)
noisy = transforms.functional.to_pil_image(noisy)
clean = transforms.functional.to_pil_image(torch.squeeze(clean, 0))
# 展示图像
plt.subplot(131)
plt.imshow(original)
plt.title('Original')
plt.axis('off')
plt.subplot(132)
plt.imshow(noisy)
plt.title('Noisy')
plt.axis('off')
plt.subplot(133)
plt.imshow(clean)
plt.title('Clean')
plt.axis('off')
plt.show()
```
运行程序后,将会显示原始图像、含有椒盐噪声的图像和去噪后的图像,如下图所示:
![denoise.png](https://img-blog.csdnimg.cn/20211009195410971.png)
阅读全文