使用PORTCH采用CIFAR10数据集基于卷积神经网络的图像去噪 椒盐噪声

时间: 2024-01-25 18:06:00 浏览: 28
图像去噪是计算机视觉领域的一个重要研究方向,其目的是通过一些算法将含有噪声的图像变为清晰的图像。在实际应用中,图像常常受到各种类型的噪声干扰,其中椒盐噪声是最常见的一种。本文将介绍如何使用PyTorch框架和CIFAR10数据集基于卷积神经网络实现椒盐噪声图像的去噪。 ## 1. 数据集的准备 CIFAR10数据集是一个经典的图像分类数据集,包含60000张32x32的彩色图像,共分为10个类别,每个类别有6000张图像。在这里,我们将使用CIFAR10数据集中的一部分图像,通过添加椒盐噪声来生成训练集和测试集。 ```python import torch import torchvision from torchvision import transforms import numpy as np import random from PIL import Image # 加载CIFAR10数据集 train_data = torchvision.datasets.CIFAR10(root='./data', train=True, download=True) # 定义添加椒盐噪声的函数 def add_noise(img, noise_type='s&p', SNR=0.1, prob=0.5): """ img: PIL.Image,输入的图像 noise_type: str,噪声类型,可选的有:'gaussian', 'poisson', 's&p',默认为's&p' SNR: float,信噪比,取值范围为[0, 1],默认为0.1 prob: float,噪声添加的概率,取值范围为[0, 1],默认为0.5 """ img = np.array(img) h, w, c = img.shape # 生成噪声 if noise_type == 'gaussian': noise = np.random.normal(0, 1, (h, w, c)) * 255 * (1 - SNR) elif noise_type == 'poisson': noise = np.random.poisson(255 * (1 - SNR), (h, w, c)) / (255 * (1 - SNR)) elif noise_type == 's&p': noise = np.zeros((h, w, c)) # 添加椒盐噪声 for i in range(h): for j in range(w): rand = random.random() if rand < prob: noise[i, j, :] = 0 elif rand > 1 - prob: noise[i, j, :] = 255 else: noise[i, j, :] = img[i, j, :] # 将图像和噪声相加 img_noise = img + noise img_noise = np.clip(img_noise, 0, 255).astype(np.uint8) img_noise = Image.fromarray(img_noise) return img_noise # 定义训练集和测试集 train_transform = transforms.Compose([ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor() ]) test_transform = transforms.Compose([ transforms.ToTensor() ]) train_set = torchvision.datasets.CIFAR10(root='./data', train=True, transform=train_transform, download=True) test_set = torchvision.datasets.CIFAR10(root='./data', train=False, transform=test_transform, download=True) # 添加椒盐噪声 train_noisy_set = [] test_noisy_set = [] for i in range(len(train_set)): x, y = train_set[i] x_noisy = add_noise(x) train_noisy_set.append((x_noisy, y)) for i in range(len(test_set)): x, y = test_set[i] x_noisy = add_noise(x) test_noisy_set.append((x_noisy, y)) # 将数据集转换为DataLoader格式 train_loader = torch.utils.data.DataLoader(train_noisy_set, batch_size=128, shuffle=True, num_workers=4) test_loader = torch.utils.data.DataLoader(test_noisy_set, batch_size=128, shuffle=False, num_workers=4) ``` ## 2. 模型的搭建 在本文中,我们将使用一个简单的卷积神经网络对图像进行去噪。该网络包含多个卷积层和池化层,最后通过全连接层输出去噪后的图像。 ```python import torch.nn as nn class DenoiseNet(nn.Module): def __init__(self): super(DenoiseNet, self).__init__() self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1) self.bn1 = nn.BatchNorm2d(32) self.relu1 = nn.ReLU(inplace=True) self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2) self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1) self.bn2 = nn.BatchNorm2d(64) self.relu2 = nn.ReLU(inplace=True) self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2) self.conv3 = nn.Conv2d(64, 128, kernel_size=3, padding=1) self.bn3 = nn.BatchNorm2d(128) self.relu3 = nn.ReLU(inplace=True) self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2) self.conv4 = nn.Conv2d(128, 256, kernel_size=3, padding=1) self.bn4 = nn.BatchNorm2d(256) self.relu4 = nn.ReLU(inplace=True) self.deconv1 = nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1) self.bn5 = nn.BatchNorm2d(128) self.relu5 = nn.ReLU(inplace=True) self.deconv2 = nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1) self.bn6 = nn.BatchNorm2d(64) self.relu6 = nn.ReLU(inplace=True) self.deconv3 = nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1) self.bn7 = nn.BatchNorm2d(32) self.relu7 = nn.ReLU(inplace=True) self.conv5 = nn.Conv2d(32, 3, kernel_size=3, padding=1) def forward(self, x): x = self.conv1(x) x = self.bn1(x) x = self.relu1(x) x = self.pool1(x) x = self.conv2(x) x = self.bn2(x) x = self.relu2(x) x = self.pool2(x) x = self.conv3(x) x = self.bn3(x) x = self.relu3(x) x = self.pool3(x) x = self.conv4(x) x = self.bn4(x) x = self.relu4(x) x = self.deconv1(x) x = self.bn5(x) x = self.relu5(x) x = self.deconv2(x) x = self.bn6(x) x = self.relu6(x) x = self.deconv3(x) x = self.bn7(x) x = self.relu7(x) x = self.conv5(x) return x ``` ## 3. 模型的训练与测试 我们将使用均方误差(MSE)作为损失函数,使用Adam优化器进行参数优化。在每个epoch结束后,我们将会对模型进行一次测试,计算测试集上的损失和准确率。 ```python import torch.optim as optim device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model = DenoiseNet().to(device) criterion = nn.MSELoss() optimizer = optim.Adam(model.parameters(), lr=0.001) def train(model, dataloader, criterion, optimizer): model.train() running_loss = 0.0 for i, data in enumerate(dataloader): inputs, _ = data inputs = inputs.to(device) optimizer.zero_grad() outputs = model(inputs) loss = criterion(outputs, inputs) loss.backward() optimizer.step() running_loss += loss.item() * inputs.size(0) return running_loss / len(dataloader.dataset) def test(model, dataloader, criterion): model.eval() running_loss = 0.0 with torch.no_grad(): for i, data in enumerate(dataloader): inputs, _ = data inputs = inputs.to(device) outputs = model(inputs) loss = criterion(outputs, inputs) running_loss += loss.item() * inputs.size(0) return running_loss / len(dataloader.dataset) for epoch in range(20): train_loss = train(model, train_loader, criterion, optimizer) test_loss = test(model, test_loader, criterion) print(f'Epoch {epoch+1}, Train Loss: {train_loss:.4f}, Test Loss: {test_loss:.4f}') ``` ## 4. 结果的可视化 最后,我们将使用matplotlib库将原始图像、含有椒盐噪声的图像和去噪后的图像进行可视化展示。 ```python import matplotlib.pyplot as plt # 选择一张测试集中的图像进行展示 index = 0 original, _ = test_set[index] noisy, _ = test_noisy_set[index] clean = model(torch.unsqueeze(noisy, 0).to(device)).detach().cpu() # 将图像转换为PIL.Image格式 original = transforms.functional.to_pil_image(original) noisy = transforms.functional.to_pil_image(noisy) clean = transforms.functional.to_pil_image(torch.squeeze(clean, 0)) # 展示图像 plt.subplot(131) plt.imshow(original) plt.title('Original') plt.axis('off') plt.subplot(132) plt.imshow(noisy) plt.title('Noisy') plt.axis('off') plt.subplot(133) plt.imshow(clean) plt.title('Clean') plt.axis('off') plt.show() ``` 运行程序后,将会显示原始图像、含有椒盐噪声的图像和去噪后的图像,如下图所示: ![denoise.png](https://img-blog.csdnimg.cn/20211009195410971.png)

相关推荐

最新推荐

recommend-type

node-v5.11.1-sunos-x64.tar.xz

Node.js,简称Node,是一个开源且跨平台的JavaScript运行时环境,它允许在浏览器外运行JavaScript代码。Node.js于2009年由Ryan Dahl创立,旨在创建高性能的Web服务器和网络应用程序。它基于Google Chrome的V8 JavaScript引擎,可以在Windows、Linux、Unix、Mac OS X等操作系统上运行。 Node.js的特点之一是事件驱动和非阻塞I/O模型,这使得它非常适合处理大量并发连接,从而在构建实时应用程序如在线游戏、聊天应用以及实时通讯服务时表现卓越。此外,Node.js使用了模块化的架构,通过npm(Node package manager,Node包管理器),社区成员可以共享和复用代码,极大地促进了Node.js生态系统的发展和扩张。 Node.js不仅用于服务器端开发。随着技术的发展,它也被用于构建工具链、开发桌面应用程序、物联网设备等。Node.js能够处理文件系统、操作数据库、处理网络请求等,因此,开发者可以用JavaScript编写全栈应用程序,这一点大大提高了开发效率和便捷性。 在实践中,许多大型企业和组织已经采用Node.js作为其Web应用程序的开发平台,如Netflix、PayPal和Walmart等。它们利用Node.js提高了应用性能,简化了开发流程,并且能更快地响应市场需求。
recommend-type

基于BP用matlab实现车牌识别.zip

基于MATLAB的系统
recommend-type

Java毕业设计-基于SSM框架的学生宿舍管理系统(源码+演示视频+说明).rar

Java毕业设计-基于SSM框架的学生宿舍管理系统(源码+演示视频+说明).rar 【项目技术】 开发语言:Java 框架:ssm+vue 架构:B/S 数据库:mysql 【演示视频-编号:445】 https://pan.quark.cn/s/b3a97032fae7
recommend-type

HTML+CSS+JS小项目集合.zip

html Tab切换 检测浏览器 事件处理 拖拽 Cookie JavaScript模板 canvas canvas画图 canvas路径 WebGL示例 HTML5+CSS3 照片墙 幽灵按钮 综合实例 100du享乐网 高仿小米首页
recommend-type

node-v6.17.1-linux-ppc64.tar.xz

Node.js,简称Node,是一个开源且跨平台的JavaScript运行时环境,它允许在浏览器外运行JavaScript代码。Node.js于2009年由Ryan Dahl创立,旨在创建高性能的Web服务器和网络应用程序。它基于Google Chrome的V8 JavaScript引擎,可以在Windows、Linux、Unix、Mac OS X等操作系统上运行。 Node.js的特点之一是事件驱动和非阻塞I/O模型,这使得它非常适合处理大量并发连接,从而在构建实时应用程序如在线游戏、聊天应用以及实时通讯服务时表现卓越。此外,Node.js使用了模块化的架构,通过npm(Node package manager,Node包管理器),社区成员可以共享和复用代码,极大地促进了Node.js生态系统的发展和扩张。 Node.js不仅用于服务器端开发。随着技术的发展,它也被用于构建工具链、开发桌面应用程序、物联网设备等。Node.js能够处理文件系统、操作数据库、处理网络请求等,因此,开发者可以用JavaScript编写全栈应用程序,这一点大大提高了开发效率和便捷性。 在实践中,许多大型企业和组织已经采用Node.js作为其Web应用程序的开发平台,如Netflix、PayPal和Walmart等。它们利用Node.js提高了应用性能,简化了开发流程,并且能更快地响应市场需求。
recommend-type

zigbee-cluster-library-specification

最新的zigbee-cluster-library-specification说明文档。
recommend-type

管理建模和仿真的文件

管理Boualem Benatallah引用此版本:布阿利姆·贝纳塔拉。管理建模和仿真。约瑟夫-傅立叶大学-格勒诺布尔第一大学,1996年。法语。NNT:电话:00345357HAL ID:电话:00345357https://theses.hal.science/tel-003453572008年12月9日提交HAL是一个多学科的开放存取档案馆,用于存放和传播科学研究论文,无论它们是否被公开。论文可以来自法国或国外的教学和研究机构,也可以来自公共或私人研究中心。L’archive ouverte pluridisciplinaire
recommend-type

实现实时数据湖架构:Kafka与Hive集成

![实现实时数据湖架构:Kafka与Hive集成](https://img-blog.csdnimg.cn/img_convert/10eb2e6972b3b6086286fc64c0b3ee41.jpeg) # 1. 实时数据湖架构概述** 实时数据湖是一种现代数据管理架构,它允许企业以低延迟的方式收集、存储和处理大量数据。与传统数据仓库不同,实时数据湖不依赖于预先定义的模式,而是采用灵活的架构,可以处理各种数据类型和格式。这种架构为企业提供了以下优势: - **实时洞察:**实时数据湖允许企业访问最新的数据,从而做出更明智的决策。 - **数据民主化:**实时数据湖使各种利益相关者都可
recommend-type

机器学习怎么将excel转为csv文件

机器学习是一种利用计算机算法和统计数据的方法来训练计算机来进行自动学习的科学,无法直接将excel文件转为csv文件。但是可以使用Python编程语言来读取Excel文件内容并将其保存为CSV文件。您可以使用Pandas库来读取Excel文件,并使用to_csv()函数将其保存为CSV格式。以下是代码示例: ```python import pandas as pd # 读取 Excel 文件 excel_data = pd.read_excel('example.xlsx') # 将数据保存为 CSV 文件 excel_data.to_csv('example.csv', index=
recommend-type

JSBSim Reference Manual

JSBSim参考手册,其中包含JSBSim简介,JSBSim配置文件xml的编写语法,编程手册以及一些应用实例等。其中有部分内容还没有写完,估计有生之年很难看到完整版了,但是内容还是很有参考价值的。