使用手写数字识别数据集在LeNet-5上增加BN层并比较与全连接、LeNet-5全连接、LeNet-5画出随epoch变化的曲线。 tensorflow代码实现

时间: 2024-03-09 21:49:37 浏览: 24
好的,那我们就使用MNIST手写数字识别数据集。下面是实现的代码: ```python import tensorflow as tf from tensorflow.keras.datasets import mnist from tensorflow.keras.layers import Input, Conv2D, Dense, Flatten, BatchNormalization, Activation, MaxPooling2D from tensorflow.keras.models import Model # 加载数据集 (x_train, y_train), (x_test, y_test) = mnist.load_data() # 归一化数据 x_train, x_test = x_train / 255.0, x_test / 255.0 # 增加维度,将数据变成4D张量 x_train = x_train[..., tf.newaxis].astype("float32") x_test = x_test[..., tf.newaxis].astype("float32") # 定义模型 def LeNet5(): inputs = Input(shape=(28, 28, 1)) x = Conv2D(filters=6, kernel_size=(5, 5), activation='relu', padding='same')(inputs) x = MaxPooling2D(pool_size=(2, 2))(x) x = Conv2D(filters=16, kernel_size=(5, 5), activation='relu', padding='valid')(x) x = MaxPooling2D(pool_size=(2, 2))(x) x = Flatten()(x) x = Dense(units=120, activation='relu')(x) x = Dense(units=84, activation='relu')(x) outputs = Dense(units=10, activation='softmax')(x) model = Model(inputs=inputs, outputs=outputs) return model def LeNet5_BN(): inputs = Input(shape=(28, 28, 1)) x = Conv2D(filters=6, kernel_size=(5, 5), padding='same')(inputs) x = BatchNormalization()(x) x = Activation('relu')(x) x = MaxPooling2D(pool_size=(2, 2))(x) x = Conv2D(filters=16, kernel_size=(5, 5), padding='valid')(x) x = BatchNormalization()(x) x = Activation('relu')(x) x = MaxPooling2D(pool_size=(2, 2))(x) x = Flatten()(x) x = Dense(units=120)(x) x = BatchNormalization()(x) x = Activation('relu')(x) x = Dense(units=84)(x) x = BatchNormalization()(x) x = Activation('relu')(x) outputs = Dense(units=10, activation='softmax')(x) model = Model(inputs=inputs, outputs=outputs) return model def fully_connected(): inputs = Input(shape=(28, 28, 1)) x = Flatten()(inputs) x = Dense(units=128, activation='relu')(x) x = Dense(units=64, activation='relu')(x) outputs = Dense(units=10, activation='softmax')(x) model = Model(inputs=inputs, outputs=outputs) return model # 定义训练参数 batch_size = 64 epochs = 10 # 创建模型 model_LeNet5 = LeNet5() model_LeNet5_BN = LeNet5_BN() model_fully_connected = fully_connected() # 编译模型 model_LeNet5.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) model_LeNet5_BN.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) model_fully_connected.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) # 训练模型 history_LeNet5 = model_LeNet5.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_data=(x_test, y_test)) history_LeNet5_BN = model_LeNet5_BN.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_data=(x_test, y_test)) history_fully_connected = model_fully_connected.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_data=(x_test, y_test)) # 绘制训练曲线 import matplotlib.pyplot as plt plt.plot(history_LeNet5.history['accuracy'], label='LeNet-5') plt.plot(history_LeNet5_BN.history['accuracy'], label='LeNet-5 + BN') plt.plot(history_fully_connected.history['accuracy'], label='Fully connected') plt.xlabel('Epoch') plt.ylabel('Accuracy') plt.legend() plt.show() ``` 上述代码中,我们定义了三个模型:LeNet-5、LeNet-5+BN和全连接模型。然后,我们将这些模型编译,并使用MNIST数据集进行训练,并绘制出它们在训练集上准确度随epoch变化的曲线。 注意,由于BatchNormalization的存在,我们在准确度图表中应该看到LeNet-5+BN模型的收敛速度更快,并且在相同的epoch数下达到更高的准确度。

相关推荐

最新推荐

recommend-type

手写数字识别:实验报告

AIstudio手写数字识别项目的实验报告,报告中有代码链接。文档包括: 1.数据预处理 2.数据加载 3.网络结构尝试:简单的多层感知器、卷积神经网络LeNet-5、循环神经网络RNN、Vgg16 4.损失函数:平方损失函数、交叉...
recommend-type

组成原理课程实验:MIPS 流水线CPU、实现36条指令、转发、冒险检测-内含源码和说明书.zip

组成原理课程实验:MIPS 流水线CPU、实现36条指令、转发、冒险检测-内含源码和说明书.zip
recommend-type

setuptools-50.0.2-py3-none-any.whl

Python库是一组预先编写的代码模块,旨在帮助开发者实现特定的编程任务,无需从零开始编写代码。这些库可以包括各种功能,如数学运算、文件操作、数据分析和网络编程等。Python社区提供了大量的第三方库,如NumPy、Pandas和Requests,极大地丰富了Python的应用领域,从数据科学到Web开发。Python库的丰富性是Python成为最受欢迎的编程语言之一的关键原因之一。这些库不仅为初学者提供了快速入门的途径,而且为经验丰富的开发者提供了强大的工具,以高效率、高质量地完成复杂任务。例如,Matplotlib和Seaborn库在数据可视化领域内非常受欢迎,它们提供了广泛的工具和技术,可以创建高度定制化的图表和图形,帮助数据科学家和分析师在数据探索和结果展示中更有效地传达信息。
recommend-type

setuptools-1.1.6.tar.gz

Python库是一组预先编写的代码模块,旨在帮助开发者实现特定的编程任务,无需从零开始编写代码。这些库可以包括各种功能,如数学运算、文件操作、数据分析和网络编程等。Python社区提供了大量的第三方库,如NumPy、Pandas和Requests,极大地丰富了Python的应用领域,从数据科学到Web开发。Python库的丰富性是Python成为最受欢迎的编程语言之一的关键原因之一。这些库不仅为初学者提供了快速入门的途径,而且为经验丰富的开发者提供了强大的工具,以高效率、高质量地完成复杂任务。例如,Matplotlib和Seaborn库在数据可视化领域内非常受欢迎,它们提供了广泛的工具和技术,可以创建高度定制化的图表和图形,帮助数据科学家和分析师在数据探索和结果展示中更有效地传达信息。
recommend-type

CEA二次开发脚本:用于ECSP配比设计

CEA二次开发脚本:用于ECSP配比设计
recommend-type

zigbee-cluster-library-specification

最新的zigbee-cluster-library-specification说明文档。
recommend-type

管理建模和仿真的文件

管理Boualem Benatallah引用此版本:布阿利姆·贝纳塔拉。管理建模和仿真。约瑟夫-傅立叶大学-格勒诺布尔第一大学,1996年。法语。NNT:电话:00345357HAL ID:电话:00345357https://theses.hal.science/tel-003453572008年12月9日提交HAL是一个多学科的开放存取档案馆,用于存放和传播科学研究论文,无论它们是否被公开。论文可以来自法国或国外的教学和研究机构,也可以来自公共或私人研究中心。L’archive ouverte pluridisciplinaire
recommend-type

实现实时数据湖架构:Kafka与Hive集成

![实现实时数据湖架构:Kafka与Hive集成](https://img-blog.csdnimg.cn/img_convert/10eb2e6972b3b6086286fc64c0b3ee41.jpeg) # 1. 实时数据湖架构概述** 实时数据湖是一种现代数据管理架构,它允许企业以低延迟的方式收集、存储和处理大量数据。与传统数据仓库不同,实时数据湖不依赖于预先定义的模式,而是采用灵活的架构,可以处理各种数据类型和格式。这种架构为企业提供了以下优势: - **实时洞察:**实时数据湖允许企业访问最新的数据,从而做出更明智的决策。 - **数据民主化:**实时数据湖使各种利益相关者都可
recommend-type

用 Python 画一个可以动的爱心

好的,我可以帮助您使用Python来绘制一个动态的爱心。您可以使用turtle库来实现。以下是实现代码: ```python import turtle import math # 设置画布和画笔 canvas = turtle.Screen() canvas.bgcolor("black") pencil = turtle.Turtle() pencil.speed(0) pencil.color("red", "pink") pencil.pensize(3) # 定义爱心函数 def draw_love(heart_size, x_offset=0, y_offset=0):
recommend-type

JSBSim Reference Manual

JSBSim参考手册,其中包含JSBSim简介,JSBSim配置文件xml的编写语法,编程手册以及一些应用实例等。其中有部分内容还没有写完,估计有生之年很难看到完整版了,但是内容还是很有参考价值的。