MNIST_DNN: MATLAB实现的简单深度学习系统

需积分: 14 5 下载量 105 浏览量 更新于2024-10-29 收藏 32KB ZIP 举报
MNIST_DNN允许用户轻松地自定义网络参数,例如每层的节点数以及激活函数的类型。当前版本的MNIST_DNN支持三种类型的激活函数,分别是sigmoid、高斯和线性单位。 如何使用MNIST_DNN: 1. 下载MNIST数据集,这是构建深度学习模型常用的基准数据集。它包含了手写数字的灰度图像及其对应的标签。 2. 将下载的MNIST数据集文件放置于MATLAB的data子文件夹中。 3. 运行converter2.m文件,这一步骤通常只需要执行一次。converter2.m文件的作用是将MNIST数据集转换为MNIST_DNN能够使用的格式。 4. 运行mnistdeepauto.m文件或者mnistclassify.m文件,这两种文件分别对应于深度自动编码器和深度分类器的实现。 5. MNIST_DNN的代码是在2013年编写的,作者表明该存储库目前仅作为存档,不建议用于生产环境。对于寻求生产级深度学习解决方案的用户,作者推荐使用Python结合Theano和Keras框架,因为它们在当前的生产环境中更为先进和成熟。 MNIST_DNN使用了MATLAB的深度学习工具箱,因此它依赖于MATLAB环境。MATLAB是一种高性能的数值计算和可视化编程环境,广泛应用于工程、科学以及数学领域。MATLAB中的工具箱提供了多种专门的函数和应用程序,以便于研究者和工程师能够方便地实现复杂的算法和模型。 值得注意的是,MNIST_DNN是一个相对简单和基础的系统,它并没有集成许多当前深度学习研究中常见的高级特性,如卷积层、循环层、批量归一化等。因此,对于需要这些高级特性来解决实际问题的用户来说,使用更为高级的深度学习框架会更加合适。 此外,MNIST_DNN的标签为"系统开源",意味着该软件是开放源代码的,可以被任何用户自由地下载、使用、修改和分发,只要遵守其许可证的规定。这为教育和研究提供了一个很好的起点,用户可以在现有代码的基础上进行扩展和创新。 最后,文件名称MNIST_DNN-master表明这是一个GitHub存储库的主分支。GitHub是目前世界上最流行的代码托管平台,它提供了Git版本控制系统的在线服务。在GitHub上,用户可以轻松地对项目进行版本控制、代码审查、协作开发以及管理软件的发布。"

import time import tensorflow.compat.v1 as tf tf.disable_v2_behavior() from tensorflow.examples.tutorials.mnist import input_data import mnist_inference import mnist_train tf.compat.v1.reset_default_graph() EVAL_INTERVAL_SECS = 10 def evaluate(mnist): with tf.Graph().as_default() as g: #定义输入与输出的格式 x = tf.compat.v1.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input') y_ = tf.compat.v1.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input') validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels} #直接调用封装好的函数来计算前向传播的结果 y = mnist_inference.inference(x, None) #计算正确率 correcgt_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) accuracy = tf.reduce_mean(tf.cast(correcgt_prediction, tf.float32)) #通过变量重命名的方式加载模型 variable_averages = tf.train.ExponentialMovingAverage(0.99) variable_to_restore = variable_averages.variables_to_restore() saver = tf.train.Saver(variable_to_restore) #每隔10秒调用一次计算正确率的过程以检测训练过程中正确率的变化 while True: with tf.compat.v1.Session() as sess: ckpt = tf.train.get_checkpoint_state(minist_train.MODEL_SAVE_PATH) if ckpt and ckpt.model_checkpoint_path: #load the model saver.restore(sess, ckpt.model_checkpoint_path) global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1] accuracy_score = sess.run(accuracy, feed_dict=validate_feed) print("After %s training steps, validation accuracy = %g" % (global_step, accuracy_score)) else: print('No checkpoint file found') return time.sleep(EVAL_INTERVAL_SECS) def main(argv=None): mnist = input_data.read_data_sets(r"D:\Anaconda123\Lib\site-packages\tensorboard\mnist", one_hot=True) evaluate(mnist) if __name__ == '__main__': tf.compat.v1.app.run()对代码进行改进

233 浏览量