如何导入 MNIST 数据集
时间: 2023-12-03 14:39:42 浏览: 114
tensorflow基础教程中所用mnist数据集
以下是两种导入MNIST数据集的方法:
1.使用TensorFlow内置函数导入MNIST数据集
```python
import tensorflow as tf
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
```
2.手动下载MNIST数据集并导入
```python
import os
import gzip
import numpy as np
def load_data():
# 训练集文件
train_images_path = './MNIST_data/train-images-idx3-ubyte.gz'
train_labels_path = './MNIST_data/train-labels-idx1-ubyte.gz'
# 测试集文件
test_images_path = './MNIST_data/t10k-images-idx3-ubyte.gz'
test_labels_path = './MNIST_data/t10k-labels-idx1-ubyte.gz'
# 读取训练集数据
with gzip.open(train_images_path, 'rb') as f:
train_images = np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1, 28, 28)
with gzip.open(train_labels_path, 'rb') as f:
train_labels = np.frombuffer(f.read(), np.uint8, offset=8)
# 读取测试集数据
with gzip.open(test_images_path, 'rb') as f:
test_images = np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1, 28, 28)
with gzip.open(test_labels_path, 'rb') as f:
test_labels = np.frombuffer(f.read(), np.uint8, offset=8)
return (train_images, train_labels), (test_images, test_labels)
# 调用函数加载数据
(train_images, train_labels), (test_images, test_labels) = load_data()
```
阅读全文