Replace the part that reads the MNIST dataset
Time: 2023-07-09 18:07:57
Sure, here is another way to read the MNIST dataset:
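```python
import gzip
import os

import numpy as np


def load_mnist(path):
    """Load the four gzip-compressed MNIST IDX files from directory `path`."""
    # Image files: skip the 16-byte header (magic number, count, rows, cols),
    # then flatten each 28x28 image into a row of 784 pixels.
    with gzip.open(os.path.join(path, 'train-images-idx3-ubyte.gz'), 'rb') as f:
        X_train = np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1, 28 * 28)
    # Label files: skip the 8-byte header (magic number, count).
    with gzip.open(os.path.join(path, 'train-labels-idx1-ubyte.gz'), 'rb') as f:
        y_train = np.frombuffer(f.read(), np.uint8, offset=8)
    with gzip.open(os.path.join(path, 't10k-images-idx3-ubyte.gz'), 'rb') as f:
        X_test = np.frombuffer(f.read(), np.uint8, offset=16).reshape(-1, 28 * 28)
    with gzip.open(os.path.join(path, 't10k-labels-idx1-ubyte.gz'), 'rb') as f:
        y_test = np.frombuffer(f.read(), np.uint8, offset=8)
    return X_train, y_train, X_test, y_test
```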
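This function uses the `gzip` module to decompress the MNIST files. `load_mnist` takes the directory that contains the four `.gz` files and returns four NumPy arrays: training images, training labels, test images and test labels, in that order. The `offset` arguments skip each file's header: 16 bytes for image files and 8 bytes for label files. `X_train` and `X_test` have shape `(n_samples, 28 * 28)`, where `n_samples` is the number of samples (60,000 for training, 10,000 for test) and `28 * 28 = 784` is the number of pixels in each flattened image; `y_train` and `y_test` have shape `(n_samples,)`. All arrays have dtype `uint8`, so pixel values lie in 0 to 255 and labels in 0 to 9. Because `np.frombuffer` wraps an immutable `bytes` object, the returned arrays are read-only; call `.copy()` or `.astype(...)` before modifying them.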
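The hard-coded `offset=16` / `offset=8` values assume the standard MNIST header layout. A slightly more defensive sketch, assuming the IDX format described alongside MNIST (two zero bytes, a type code where `0x08` means unsigned byte, a dimension count, then one big-endian 32-bit size per dimension), reads the header instead of hard-coding it. `parse_idx`, `load_idx_gz` and `load_mnist_checked` are hypothetical helper names introduced here, not part of any library, and only the unsigned-byte type used by all four MNIST files is handled.

```python
import gzip
import os
import struct

import numpy as np


def parse_idx(data: bytes) -> np.ndarray:
    """Parse decompressed IDX bytes (hypothetical helper; uint8 data only)."""
    # Header: two zero bytes, a type code, and the number of dimensions.
    zero, type_code, ndim = struct.unpack(">HBB", data[:4])
    if zero != 0 or type_code != 0x08:
        raise ValueError("not an unsigned-byte IDX file")
    # One big-endian unsigned 32-bit size per dimension.
    shape = struct.unpack(">" + "I" * ndim, data[4:4 + 4 * ndim])
    # Data starts right after the header: 16 bytes for images, 8 for labels.
    offset = 4 + 4 * ndim
    # reshape raises ValueError if the payload size disagrees with the header.
    return np.frombuffer(data, dtype=np.uint8, offset=offset).reshape(shape)


def load_idx_gz(filepath: str) -> np.ndarray:
    """Decompress one .gz IDX file and parse it (hypothetical helper)."""
    with gzip.open(filepath, "rb") as f:
        return parse_idx(f.read())


def load_mnist_checked(path):
    """Hypothetical header-checked variant of load_mnist."""
    X_train = load_idx_gz(os.path.join(path, "train-images-idx3-ubyte.gz"))
    y_train = load_idx_gz(os.path.join(path, "train-labels-idx1-ubyte.gz"))
    X_test = load_idx_gz(os.path.join(path, "t10k-images-idx3-ubyte.gz"))
    y_test = load_idx_gz(os.path.join(path, "t10k-labels-idx1-ubyte.gz"))
    # Flatten (n, 28, 28) images to (n, 784) to match load_mnist's output.
    return (X_train.reshape(len(X_train), -1), y_train,
            X_test.reshape(len(X_test), -1), y_test)
```

With the four files in a directory such as `data/`, `load_mnist_checked('data')` should return arrays with the same shapes as `load_mnist`. The difference is that the sample count and image size now come from the headers, so a file with a wrong magic number, or whose length disagrees with its header, raises `ValueError` instead of silently yielding misaligned rows.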