def loadData(self,mnist_url): f = gzip.open(str(mnist_url), 'rb') training_data, validation_data, test_data = pickle.load(f, encoding="latin1") f.close() training_inputs = [np.reshape(x, (784, 1)) for x in training_data[0]] #(50000,784,1) training_results = [self.vectorized_result(y) for y in training_data[1]] #(50000,10,1) training_data = zip(training_inputs, training_results) # validation_inputs = [np.reshape(x, (784, 1)) for x in validation_data[0]] # validation_data = zip(validation_inputs, validation_data[1]) # test_inputs = [np.reshape(x, (784, 1)) for x in test_data[0]] test_data = zip(test_inputs, test_data[1]) return (training_inputs,training_results,test_data)
时间: 2024-02-23 11:56:33 浏览: 81
mnist.rar_68B9_MNIST_MNIST 数据集_site:en.pudn.com_tensenflow
这段代码定义了一个名为"loadData"的函数,它的参数是"mnist_url",表示MNIST数据集的路径。在函数中,首先使用gzip模块打开数据集文件,然后使用pickle模块将数据集文件中的数据读取出来,并将其分为训练数据、验证数据和测试数据。其中,训练数据和验证数据被合并成了一个列表"training_data",测试数据被转换为一个列表"test_data"。接下来,将训练数据中的每个样本按照784维展开成一个784x1的矩阵,并将其存储在"training_inputs"列表中。将训练数据的标签向量化,即将每个标签转换为一个10维的向量,其中对应标签的位置为1,其他位置为0,并将其存储在"training_results"列表中。最后,将"training_inputs"和"training_results"打包成一个元组,并将其作为函数的返回值。注释掉的代码是用来处理验证数据的,因为在这个函数中并没有使用验证数据,所以被注释掉了。
阅读全文