深度学习框架pytorch通过卷积神经网络实现mnist手写数据集的识别,卷积层数、卷积核数、卷积核大小、padding值、stride值、学习率、优化方式等均可自己选择,计算模型精确率,训练误差,测试误差,并以图表形式显示

时间: 2025-01-27 08:15:20 浏览: 19

好的,下面是一个使用PyTorch框架通过卷积神经网络(CNN)实现MNIST手写数字数据集识别的示例代码。这个示例包括卷积层数、卷积核数、卷积核大小、padding值、stride值、学习率、优化方式等参数的设置,并计算模型精确率、训练误差和测试误差,最后以图表形式显示结果。

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# 定义超参数
batch_size = 64
learning_rate = 0.01
num_epochs = 10

# 数据预处理
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

# 加载MNIST数据集
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=transform, download=True)

train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

# 定义卷积神经网络
class ConvNet(nn.Module):
    def __init__(self):
        super(ConvNet, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 32, kernel_size=3, padding=1, stride=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3, padding=1, stride=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )
        self.fc = nn.Linear(7 * 7 * 64, 10)
        
    def forward(self, x):
        out = self.layer1(x)
        out = self.layer2(out)
        out = out.reshape(out.size(0), -1)
        out = self.fc(out)
        return out

model = ConvNet()

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# 训练模型
total_step = len(train_loader)
train_losses = []
test_losses = []
for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for images, labels in train_loader:
        outputs = model(images)
        loss = criterion(outputs, labels)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
    
    avg_train_loss = running_loss / len(train_loader)
    train_losses.append(avg_train_loss)
    
    # 测试模型
    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        test_loss = 0.0
        for images, labels in test_loader:
            outputs = model(images)
            loss = criterion(outputs, labels)
            test_loss += loss.item()
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
        
        avg_test_loss = test_loss / len(test_loader)
        test_losses.append(avg_test_loss)
        accuracy = 100 * correct / total
    
    print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {avg_train_loss:.4f}, Test Loss: {avg_test_loss:.4f}, Accuracy: {accuracy:.2f}%')

# 绘制训练和测试误差图
plt.plot(range(1, num_epochs+1), train_losses, label='Train Loss')
plt.plot(range(1, num_epochs+1), test_losses, label='Test Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.show()
向AI提问 loading 发送消息图标

相关推荐

大家在看

recommend-type

使用Arduino监控ECG和呼吸-项目开发

使用TI出色的ADS1292R芯片连接Arduino,以查看您的ECG,呼吸和心率。
recommend-type

航空发动机缺陷检测数据集VOC+YOLO格式291张4类别.7z

数据集格式:Pascal VOC格式+YOLO格式(不包含分割路径的txt文件,仅仅包含jpg图片以及对应的VOC格式xml文件和yolo格式txt文件) 图片数量(jpg文件个数):291 标注数量(xml文件个数):291 标注数量(txt文件个数):291 标注类别数:4 标注类别名称:[“crease”,“damage”,“dot”,“scratch”] 更多信息:blog.csdn.net/FL1623863129/article/details/139274954
recommend-type

python基础教程:pandas DataFrame 行列索引及值的获取的方法

pandas DataFrame是二维的,所以,它既有列索引,又有行索引 上一篇里只介绍了列索引: import pandas as pd df = pd.DataFrame({'A': [0, 1, 2], 'B': [3, 4, 5]}) print df # 结果: A B 0 0 3 1 1 4 2 2 5 行索引自动生成了 0,1,2 如果要自己指定行索引和列索引,可以使用 index 和 column 参数: 这个数据是5个车站10天内的客流数据: ridership_df = pd.DataFrame( data=[[ 0, 0, 2, 5, 0],
recommend-type

【微电网优化】基于粒子群优化IEEE经典微电网结构附matlab代码.zip

1.版本:matlab2014/2019a,内含运行结果,不会运行可私信 2.领域:智能优化算法、神经网络预测、信号处理、元胞自动机、图像处理、路径规划、无人机等多种领域的Matlab仿真,更多内容可点击博主头像 3.内容:标题所示,对于介绍可点击主页搜索博客 4.适合人群:本科,硕士等教研学习使用 5.博客介绍:热爱科研的Matlab仿真开发者,修心和技术同步精进,matlab项目合作可si信
recommend-type

三层神经网络模型matlab版

纯手写三层神经网络,有数据,无需其他函数,直接运行,包括batchBP和singleBP。

最新推荐

recommend-type

Pytorch使用MNIST数据集实现CGAN和生成指定的数字方式

在本教程中,我们将探讨如何使用PyTorch框架来实现条件生成对抗网络(CGAN)并利用MNIST数据集生成指定数字的图像。CGAN是一种扩展了基础生成对抗网络(GAN)的概念,它允许在生成过程中加入额外的条件信息,如类...
recommend-type

Pytorch实现的手写数字mnist识别功能完整示例

在本示例中,我们将讨论如何使用Pytorch实现手写数字的识别,特别是针对MNIST数据集。MNIST数据集包含了60000个训练样本和10000个测试样本,每个样本都是28x28像素的手写数字图像。 首先,我们需要导入必要的库,...
recommend-type

pytorch 利用lstm做mnist手写数字识别分类的实例

在本实例中,我们将探讨如何使用PyTorch构建一个基于LSTM(长短期记忆网络)的手写数字识别模型,以解决MNIST数据集的问题。MNIST数据集包含大量的手写数字图像,通常用于训练和测试计算机视觉算法,尤其是深度学习...
recommend-type

使用卷积神经网络(CNN)做人脸识别的示例代码

卷积神经网络(CNN)在人脸识别领域的应用已经成为现代计算机视觉技术的重要组成部分。相较于早期的人脸识别算法,如特征脸法,CNN以其强大的特征提取能力和自动学习能力,显著提升了人脸识别的准确性和效率。特征脸...
recommend-type

【深度学习入门】Paddle实现手写数字识别详解(基于DenseNet)

1. **MNIST数据集**:MNIST是广泛用于手写数字识别的经典数据集,由60,000张训练图像和10,000张测试图像组成。每个图像都是28x28像素的灰度图像,对应的标签是0到9之间的整数。它是初学者入门深度学习的首选任务,...
recommend-type

全面介绍酒店设施的培训纲要

从提供的信息来看,可以推断这是一份关于酒店设施培训的纲要文档,虽然具体的文件内容并未提供,但是可以从标题和描述中提炼一些相关知识点和信息。 首先,关于标题“酒店《酒店设施》培训活动纲要”,我们可以得知该文档的内容是关于酒店行业的培训,培训内容专注于酒店的设施使用和管理。培训活动纲要作为一项计划性文件,通常会涉及以下几个方面: 1. 培训目标:这可能是文档中首先介绍的部分,明确培训的目的是为了让员工熟悉并掌握酒店各项设施的功能、操作以及维护等。目标可以是提高员工服务效率、增强客户满意度、确保设施安全运行等。 2. 培训对象:该培训可能针对的是酒店内所有需要了解或操作酒店设施的员工,比如前台接待、客房服务员、工程技术人员、维修人员等。 3. 培训内容:这应该包括了酒店设施的详细介绍,比如客房内的家具、电器,公共区域的休闲娱乐设施,健身房、游泳池等体育设施,以及会议室等商务设施。同时,也可能会涉及到设备的使用方法、安全规范、日常维护、故障排查等。 4. 培训方式:这部分会说明是通过什么形式进行培训的,如现场操作演示、视频教学、文字说明、模拟操作、考核测试等。 5. 培训时间:这可能涉及培训的总时长、分阶段的时间表、各阶段的时间分配以及具体的培训日期等。 6. 培训效果评估:介绍如何评估培训效果,可能包括员工的反馈、考试成绩、实际操作能力的测试、工作中的应用情况等。 再来看描述,提到该文档“是一份很不错的参考资料,具有较高参考价值”,说明这个培训纲要经过整理,能够为酒店行业的人士提供实用的信息和指导。这份纲要可能包含了经过实践检验的最佳实践,以及专家们总结的经验和技巧,这些都是员工提升技能、提升服务质量的宝贵资源。 至于“感兴趣可以下载看看”,这表明该培训纲要对有兴趣了解酒店管理、特别是酒店设施管理的人士开放,这可能意味着纲要内容足够通俗易懂,即使是没有酒店行业背景的人员也能够从中获益。 虽然文件标签没有提供,但是结合标题和描述,我们可以推断标签可能与“酒店管理”、“设施操作”、“员工培训”、“服务技能提升”、“安全规范”等有关。 最后,“【下载自www.glzy8.com管理资源吧】酒店《酒店设施》培训活动纲要.doc”表明了文件来源和文件格式。"www.glzy8.com"很可能是一个提供管理资源下载的网站,其中"glzy"可能是对“管理资源”的缩写,而".doc"格式则说明这是一个Word文档,用户可以通过点击链接下载使用。 总结来说,虽然具体文件内容未知,但是通过提供的标题和描述,我们可以了解到该文件是一个酒店行业内部使用的设施培训纲要,它有助于提升员工对酒店设施的理解和操作能力,进而增强服务质量和客户满意度。而文件来源网站,则显示了该文档具有一定的行业共享性和实用性。
recommend-type

Qt零基础到精通系列:全面提升轮播图开发技能的15堂必修课

# 摘要 本文全面探讨了基于Qt框架的轮播图开发技术。文章首先介绍了Qt框架的基本安装、配置和图形用户界面的基础知识,重点讨论了信号与槽机制以及Widgets组件的使用。接着深入分析了轮播图的核心机制,包括工作原理、关键技术点和性能优化策略。在此基础上,文章详细阐述了使用Qt
recommend-type

创建的conda环境无法配置到pycharm

### 配置 Conda 虚拟环境到 PyCharm 的方法 在 PyCharm 中配置已创建的 Conda 虚拟环境可以通过以下方式实现: #### 方法一:通过新建 Python 工程的方式配置 当您创建一个新的 Python 工程时,可以按照以下流程完成 Conda 环境的配置: 1. 创建一个新项目,在弹出窗口中找到 **Python Interpreter** 设置区域。 2. 点击右侧的齿轮图标并选择 **Add...** 来添加新的解释器。 3. 在弹出的对话框中选择 **Conda Environment** 选项卡[^1]。 4. 如果尚未安装 Conda 或未检测到其路
recommend-type

Java与JS结合实现动态下拉框搜索提示功能

标题中的“java+js实现下拉框提示搜索功能”指的是一种在Web开发中常用的功能,即当用户在输入框中输入文本时,系统能够实时地展示一个下拉列表,其中包含与用户输入相关联的数据项。这个过程是动态的,意味着用户每输入一个字符,下拉列表就会更新一次,从而加快用户的查找速度并提升用户体验。此功能通常用在搜索框或者表单字段中。 描述中提到的“在输入框中输入信息,会出现下拉框列出符合条件的数据,实现动态的查找功能”具体指的是这一功能的实现方法。具体实现方式通常涉及前端技术JavaScript,可能还会结合后端技术Java,以及Ajax技术来获取数据并动态更新页面内容。 关于知识点的详细说明: 1. JavaScript基础 JavaScript是一种客户端脚本语言,用于实现前端页面的动态交互和数据处理。实现下拉框提示搜索功能需要用到的核心JavaScript技术包括事件监听、DOM操作、数据处理等。其中,事件监听可以捕捉用户输入时的动作,DOM操作用于动态创建或更新下拉列表元素,数据处理则涉及对用户输入的字符串进行匹配和筛选。 2. Ajax技术 Ajax(Asynchronous JavaScript and XML)是一种在无需重新加载整个页面的情况下,能够与服务器交换数据并更新部分网页的技术。利用Ajax,可以在用户输入数据时异步请求服务器端的Java接口,获取匹配的搜索结果,然后将结果动态插入到下拉列表中。这样用户体验更加流畅,因为整个过程不需要重新加载页面。 3. Java后端技术 Java作为后端开发语言,常用于处理服务器端逻辑。实现动态查找功能时,Java主要承担的任务是对数据库进行查询操作。根据Ajax请求传递的用户输入参数,Java后端通过数据库查询接口获取数据,并将查询结果以JSON或其他格式返回给前端。 4. 实现步骤 - 创建输入框,并为其绑定事件监听器(如keyup事件)。 - 当输入框中的文本变化时,触发事件处理函数。 - 事件处理函数中通过Ajax向后端发送请求,并携带输入框当前的文本作为查询参数。 - 后端Java接口接收到请求后,根据传入参数在数据库中执行查询操作。 - 查询结果通过Java接口返回给前端。 - 前端JavaScript接收到返回的数据后,更新页面上显示的下拉列表。 - 显示的下拉列表应能反映当前输入框中的文本内容,随着用户输入实时变化。 5. 关键技术细节 - **前端数据绑定和展示**:在JavaScript中处理Ajax返回的数据,并通过DOM操作技术更新下拉列表元素。 - **防抖和节流**:为输入框绑定的事件处理函数可能过于频繁触发,可能会导致服务器负载过重。因此,实际实现中通常会引入防抖(debounce)和节流(throttle)技术来减少请求频率。 - **用户体验优化**:下拉列表需要按匹配度排序,并且要处理大量数据时的显示问题,以保持良好的用户体验。 6. 安全和性能考虑 - **数据过滤和验证**:前端对用户输入应该进行适当过滤和验证,防止SQL注入等安全问题。 - **数据的加载和分页**:当数据量很大时,应该采用分页或其他技术来减少一次性加载的数据量,避免页面卡顿。 - **数据缓存**:对于经常查询且不常变动的数据,可以采用前端缓存来提高响应速度。 在文件名称列表中提到的"Ajax",实际上是一个关键的技术要点。实现动态下拉框提示功能往往需要将JavaScript和Ajax配合使用,实现页面的异步数据更新。这里的Ajax文件可能包含用于处理数据异步加载逻辑的JavaScript代码。 通过以上知识点的详细阐述,可以清晰了解java和js结合实现下拉框提示搜索功能的技术原理和实现步骤。这涉及到前端JavaScript编程、后端Java编程、Ajax数据交互、以及前后端数据处理和展示等多方面的技术细节。掌握这些技术能够有效地在Web应用中实现交互式的动态下拉框提示功能。
recommend-type

【LVGL快速入门与精通】:10个实用技巧,让你从新手到专家

# 摘要 LVGL(Light and Versatile Graphics Library)是一个开源的嵌入式图形库,专为资源受限的嵌入式系统设计。本文全面介绍LVGL图形库,探讨其核心概念、基础及高级应用技巧,以及如何在嵌入式系统中实现复杂的用户界面和优化用户体验。文章还分析了LVGL与硬件的集成方法、
手机看
程序员都在用的中文IT技术交流社区

程序员都在用的中文IT技术交流社区

专业的中文 IT 技术社区,与千万技术人共成长

专业的中文 IT 技术社区,与千万技术人共成长

关注【CSDN】视频号,行业资讯、技术分享精彩不断,直播好礼送不停!

关注【CSDN】视频号,行业资讯、技术分享精彩不断,直播好礼送不停!

客服 返回
顶部