def loadData(self,mnist_url): f = gzip.open(str(mnist_url), 'rb') training_data, validation_data, test_data = pickle.load(f, encoding="latin1") f.close() training_inputs = [np.reshape(x, (784, 1)) for x in training_data[0]] #(50000,784,1) training_results = [self.vectorized_result(y) for y in training_data[1]] #(50000,10,1) training_data = zip(training_inputs, training_results) # validation_inputs = [np.reshape(x, (784, 1)) for x in validation_data[0]] # validation_data = zip(validation_inputs, validation_data[1]) # test_inputs = [np.reshape(x, (784, 1)) for x in test_data[0]] test_data = zip(test_inputs, test_data[1]) return (training_inputs,training_results,test_data)
时间: 2024-02-23 15:56:33 浏览: 28
这段代码定义了一个名为"loadData"的函数,它的参数是"mnist_url",表示MNIST数据集的路径。在函数中,首先使用gzip模块打开数据集文件,然后使用pickle模块将数据集文件中的数据读取出来,并将其分为训练数据、验证数据和测试数据。其中,训练数据和验证数据被合并成了一个列表"training_data",测试数据被转换为一个列表"test_data"。接下来,将训练数据中的每个样本按照784维展开成一个784x1的矩阵,并将其存储在"training_inputs"列表中。将训练数据的标签向量化,即将每个标签转换为一个10维的向量,其中对应标签的位置为1,其他位置为0,并将其存储在"training_results"列表中。最后,将"training_inputs"和"training_results"打包成一个元组,并将其作为函数的返回值。注释掉的代码是用来处理验证数据的,因为在这个函数中并没有使用验证数据,所以被注释掉了。
相关问题
import gzip import os import pickle import numpy as np def load_mnist(path, kind='train'): labels_path = os.path.join(path, '%s-labels-idx1-ubyte.gz' % kind) images_path = os.path.join(path, '%s-images-idx3-ubyte.gz' % kind) with gzip.open(labels_path, 'rb') as lbpath: labels = np.frombuffer(lbpath.read(), dtype=np.uint8, offset=8) with gzip.open(images_path, 'rb') as imgpath: images = np.frombuffer(imgpath.read(), dtype=np.uint8, offset=16).reshape(len(labels), 784) return images, labels def get_mnist_data(num_training=5000, num_validation=500, num_test=500): mnist_dir = r'D:\daima\mnist' # 修改为mnist数据集所在的目录 X_train, y_train = load_mnist(mnist_dir, kind='train') X_test, y_test = load_mnist(mnist_dir, kind='t10k') print(X_train.shape) mask = range(num_training, num_training + num_validation) X_val = X_train[mask] y_val = y_train[mask] mask = range(num_training) X_train = X_train[mask] y_train = y_train[mask] mask = range(num_test) X_test = X_test[mask] y_test = y_test[mask] X_train = X_train.astype('float32') / 255 X_val = X_val.astype('float32') / 255 X_test = X_test.astype('float32') / 255 return { 'X_train': X_train, 'y_train': y_train, 'X_val': X_val, 'y_val': y_val, 'X_test': X_test, 'y_test': y_test, },这是读取mnist的函数,如何把解包时给定的值数量从两个增加到4个
在这个函数中,load_mnist函数返回了两个值,即images和labels,但是这个函数只使用了一个返回值,所以只需要将这个函数的返回值改为一个元组即可。具体来说,将return images, labels改为return (images, labels)即可。这样get_mnist_data函数就可以使用解包时给定的值数量为4了。修改后的代码如下:
```
import gzip
import os
import pickle
import numpy as np
def load_mnist(path, kind='train'):
labels_path = os.path.join(path, '%s-labels-idx1-ubyte.gz' % kind)
images_path = os.path.join(path, '%s-images-idx3-ubyte.gz' % kind)
with gzip.open(labels_path, 'rb') as lbpath:
labels = np.frombuffer(lbpath.read(), dtype=np.uint8, offset=8)
with gzip.open(images_path, 'rb') as imgpath:
images = np.frombuffer(imgpath.read(), dtype=np.uint8, offset=16).reshape(len(labels), 784)
return images, labels
def get_mnist_data(num_training=5000, num_validation=500, num_test=500):
mnist_dir = r'D:\daima\mnist' # 修改为mnist数据集所在的目录
X_train, y_train = load_mnist(mnist_dir, kind='train')
X_test, y_test = load_mnist(mnist_dir, kind='t10k')
print(X_train.shape)
mask = range(num_training, num_training + num_validation)
X_val = X_train[mask]
y_val = y_train[mask]
mask = range(num_training)
X_train = X_train[mask]
y_train = y_train[mask]
mask = range(num_test)
X_test = X_test[mask]
y_test = y_test[mask]
X_train = X_train.astype('float32') / 255
X_val = X_val.astype('float32') / 255
X_test = X_test.astype('float32') / 255
return X_train, y_train, X_val, y_val, X_test, y_test
```
import time import tensorflow.compat.v1 as tf tf.disable_v2_behavior() from tensorflow.examples.tutorials.mnist import input_data import mnist_inference import mnist_train tf.compat.v1.reset_default_graph() EVAL_INTERVAL_SECS = 10 def evaluate(mnist): with tf.Graph().as_default() as g: #定义输入与输出的格式 x = tf.compat.v1.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input') y_ = tf.compat.v1.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input') validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels} #直接调用封装好的函数来计算前向传播的结果 y = mnist_inference.inference(x, None) #计算正确率 correcgt_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1)) accuracy = tf.reduce_mean(tf.cast(correcgt_prediction, tf.float32)) #通过变量重命名的方式加载模型 variable_averages = tf.train.ExponentialMovingAverage(0.99) variable_to_restore = variable_averages.variables_to_restore() saver = tf.train.Saver(variable_to_restore) #每隔10秒调用一次计算正确率的过程以检测训练过程中正确率的变化 while True: with tf.compat.v1.Session() as sess: ckpt = tf.train.get_checkpoint_state(minist_train.MODEL_SAVE_PATH) if ckpt and ckpt.model_checkpoint_path: #load the model saver.restore(sess, ckpt.model_checkpoint_path) global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1] accuracy_score = sess.run(accuracy, feed_dict=validate_feed) print("After %s training steps, validation accuracy = %g" % (global_step, accuracy_score)) else: print('No checkpoint file found') return time.sleep(EVAL_INTERVAL_SECS) def main(argv=None): mnist = input_data.read_data_sets(r"D:\Anaconda123\Lib\site-packages\tensorboard\mnist", one_hot=True) evaluate(mnist) if __name__ == '__main__': tf.compat.v1.app.run()对代码进行改进
以下是对代码的改进建议:
1. 在代码开头添加注释,简要说明代码功能和使用方法。
2. 将导入模块的语句放在代码开头。
3. 将模型保存路径和评估时间间隔定义为常量,并使用有意义的变量名。
4. 将计算正确率和加载模型的过程封装为函数。
5. 在主函数中调用评估函数。
改进后的代码如下:
```
# 该代码实现了使用已训练好的模型对 MNIST 数据集进行评估
import time
import tensorflow.compat.v1 as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_inference
import mnist_train
# 定义常量
MODEL_SAVE_PATH = 'model/'
EVAL_INTERVAL_SECS = 10
def evaluate(mnist):
"""
计算模型在验证集上的正确率
"""
with tf.Graph().as_default() as g:
# 定义输入和输出格式
x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE], name='x-input')
y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE], name='y-input')
# 直接调用封装好的函数计算前向传播结果
y = mnist_inference.inference(x, None)
# 计算正确率
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
# 加载模型
variable_averages = tf.train.ExponentialMovingAverage(mnist_train.MOVING_AVERAGE_DECAY)
variables_to_restore = variable_averages.variables_to_restore()
saver = tf.train.Saver(variables_to_restore)
# 在验证集上计算正确率
with tf.Session() as sess:
ckpt = tf.train.get_checkpoint_state(MODEL_SAVE_PATH)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
accuracy_score = sess.run(accuracy, feed_dict={x: mnist.validation.images, y_: mnist.validation.labels})
print("After %s training steps, validation accuracy = %g" % (global_step, accuracy_score))
else:
print('No checkpoint file found')
def main(argv=None):
# 读取数据集
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
# 每隔一定时间评估模型在验证集上的正确率
while True:
evaluate(mnist)
time.sleep(EVAL_INTERVAL_SECS)
if __name__ == '__main__':
tf.app.run()
```