用TensorFlow2.0将手写数字集(MNIST)分为训练集,验证集和测试集(6:2:2),进行Tensorflow2的数据加载、多神经网络模型的搭建、分类模型的交叉验证、模型保存和新数据预测(自己手写一个数字拍照到保存的模型中验证)

时间: 2023-06-13 11:07:08 浏览: 38
好的,我可以给您提供一个TensorFlow2.0实现MNIST手写数字分类的代码,包括数据加载、多神经网络模型的搭建、分类模型的交叉验证、模型保存和新数据预测(自己手写一个数字拍照到保存的模型中验证)。 首先,我们需要导入必要的库: ``` import tensorflow as tf from tensorflow import keras from tensorflow.keras import layers import numpy as np import matplotlib.pyplot as plt import cv2 ``` 然后,我们可以加载MNIST数据集并将其分为训练集、验证集和测试集。 ``` (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data() # Normalize pixel values to be between 0 and 1 x_train = x_train.astype("float32") / 255 x_test = x_test.astype("float32") / 255 # Split training data into training and validation sets x_train, x_val = x_train[:-12000], x_train[-12000:] y_train, y_val = y_train[:-12000], y_train[-12000:] # Print shapes of training, validation, and test sets print("x_train shape:", x_train.shape) print("y_train shape:", y_train.shape) print("x_val shape:", x_val.shape) print("y_val shape:", y_val.shape) print("x_test shape:", x_test.shape) print("y_test shape:", y_test.shape) ``` 接下来,我们可以搭建多神经网络模型并进行交叉验证。这里我们使用了一个包含两个隐藏层的全连接神经网络。 ``` model = keras.Sequential([ layers.Flatten(input_shape=(28, 28)), layers.Dense(256, activation='relu'), layers.Dropout(0.3), layers.Dense(128, activation='relu'), layers.Dropout(0.2), layers.Dense(10, activation='softmax') ]) model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy']) # Cross-validation history = model.fit(x_train, y_train, epochs=10, batch_size=64, validation_data=(x_val, y_val)) ``` 我们可以使用matplotlib来绘制训练和验证的准确率和损失值。 ``` # Plot training and validation accuracy values plt.plot(history.history['accuracy']) plt.plot(history.history['val_accuracy']) plt.title('Model accuracy') plt.ylabel('Accuracy') plt.xlabel('Epoch') plt.legend(['Train', 'Val'], loc='upper left') plt.show() # Plot training and validation loss values plt.plot(history.history['loss']) plt.plot(history.history['val_loss']) plt.title('Model loss') plt.ylabel('Loss') plt.xlabel('Epoch') plt.legend(['Train', 'Val'], loc='upper left') plt.show() ``` 当我们的模型训练好之后,我们可以使用以下代码来保存模型。 ``` model.save('mnist_model.h5') ``` 最后,我们可以使用以下代码来预测我们自己手写的数字。 ``` # Load saved model model = keras.models.load_model('mnist_model.h5') # Read image and resize img = cv2.imread('test_image.png', cv2.IMREAD_GRAYSCALE) img = cv2.resize(img, (28, 28)) # Invert colors and normalize pixel values to be between 0 and 1 img = cv2.bitwise_not(img) img = img.astype("float32") / 255 # Reshape image to match model input shape img = np.reshape(img, (1, 28, 28)) # Make prediction prediction = model.predict(img) # Print prediction print("Prediction:", np.argmax(prediction[0])) ``` 这样,我们就完成了使用TensorFlow2.0对MNIST手写数字集进行分类,并且可以使用我们训练好的模型来预测我们自己手写的数字。

相关推荐

最新推荐

recommend-type

使用tensorflow实现VGG网络,训练mnist数据集方式

主要介绍了使用tensorflow实现VGG网络,训练mnist数据集方式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
recommend-type

tensorflow实现残差网络方式(mnist数据集)

主要介绍了tensorflow实现残差网络方式(mnist数据集),具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
recommend-type

基于多分类非线性SVM(+交叉验证法)的MNIST手写数据集训练(无框架)算法

2.通过一对一方法将45类训练样本((0,1),(0,2),…(1,2)…(2,3))送入交叉验证法,训练算法为smo 3.得出45个模型,测试时在利用投票法判定 数据结构 '''***********************************************************...
recommend-type

Pytorch使用MNIST数据集实现CGAN和生成指定的数字方式

今天小编就为大家分享一篇Pytorch使用MNIST数据集实现CGAN和生成指定的数字方式,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
recommend-type

基于TensorFlow的CNN实现Mnist手写数字识别

本文实例为大家分享了基于TensorFlow的CNN实现Mnist手写数字识别的具体代码,供大家参考,具体内容如下 一、CNN模型结构 输入层:Mnist数据集(28*28) 第一层卷积:感受视野5*5,步长为1,卷积核:32个 第一层...
recommend-type

zigbee-cluster-library-specification

最新的zigbee-cluster-library-specification说明文档。
recommend-type

管理建模和仿真的文件

管理Boualem Benatallah引用此版本:布阿利姆·贝纳塔拉。管理建模和仿真。约瑟夫-傅立叶大学-格勒诺布尔第一大学,1996年。法语。NNT:电话:00345357HAL ID:电话:00345357https://theses.hal.science/tel-003453572008年12月9日提交HAL是一个多学科的开放存取档案馆,用于存放和传播科学研究论文,无论它们是否被公开。论文可以来自法国或国外的教学和研究机构,也可以来自公共或私人研究中心。L’archive ouverte pluridisciplinaire
recommend-type

MATLAB柱状图在信号处理中的应用:可视化信号特征和频谱分析

![matlab画柱状图](https://img-blog.csdnimg.cn/3f32348f1c9c4481a6f5931993732f97.png) # 1. MATLAB柱状图概述** MATLAB柱状图是一种图形化工具,用于可视化数据中不同类别或组的分布情况。它通过绘制垂直条形来表示每个类别或组中的数据值。柱状图在信号处理中广泛用于可视化信号特征和进行频谱分析。 柱状图的优点在于其简单易懂,能够直观地展示数据分布。在信号处理中,柱状图可以帮助工程师识别信号中的模式、趋势和异常情况,从而为信号分析和处理提供有价值的见解。 # 2. 柱状图在信号处理中的应用 柱状图在信号处理
recommend-type

HSV转为RGB的计算公式

HSV (Hue, Saturation, Value) 和 RGB (Red, Green, Blue) 是两种表示颜色的方式。下面是将 HSV 转换为 RGB 的计算公式: 1. 将 HSV 中的 S 和 V 值除以 100,得到范围在 0~1 之间的值。 2. 计算色相 H 在 RGB 中的值。如果 H 的范围在 0~60 或者 300~360 之间,则 R = V,G = (H/60)×V,B = 0。如果 H 的范围在 60~120 之间,则 R = ((120-H)/60)×V,G = V,B = 0。如果 H 的范围在 120~180 之间,则 R = 0,G = V,B =
recommend-type

JSBSim Reference Manual

JSBSim参考手册,其中包含JSBSim简介,JSBSim配置文件xml的编写语法,编程手册以及一些应用实例等。其中有部分内容还没有写完,估计有生之年很难看到完整版了,但是内容还是很有参考价值的。