python导入MNIST数据集代码
时间: 2023-07-09 16:31:10 浏览: 123
MNIST包含读取数据代码
以下是Python中导入MNIST数据集的代码:
```python
import gzip
import pickle
import numpy as np
def load_data():
with gzip.open('mnist.pkl.gz', 'rb') as f:
training_data, validation_data, test_data = pickle.load(f, encoding='bytes')
return (training_data, validation_data, test_data)
def load_data_wrapper():
tr_d, va_d, te_d = load_data()
training_inputs = [np.reshape(x, (784, 1)) for x in tr_d[0]]
training_results = [vectorized_result(y) for y in tr_d[1]]
training_data = zip(training_inputs, training_results)
validation_inputs = [np.reshape(x, (784, 1)) for x in va_d[0]]
validation_data = zip(validation_inputs, va_d[1])
test_inputs = [np.reshape(x, (784, 1)) for x in te_d[0]]
test_data = zip(test_inputs, te_d[1])
return (training_data, validation_data, test_data)
def vectorized_result(j):
e = np.zeros((10, 1))
e[j] = 1.0
return e
```
这段代码会读取MNIST数据集并将其分成三个部分:训练数据、验证数据和测试数据。训练数据包括输入和期望输出,验证和测试数据只包括输入。该代码还包括将期望输出向量化的函数。
阅读全文