在手写数字识别的时候,读取数据集的python代码
时间: 2023-12-10 08:40:25 浏览: 68
手写数字识别通常使用的是MNIST数据集,可以使用以下Python代码读取MNIST数据集:
```python
import os
import struct
import numpy as np
def load_mnist(path, kind='train'):
"""Load MNIST data from `path`"""
labels_path = os.path.join(path, f'{kind}-labels-idx1-ubyte')
images_path = os.path.join(path, f'{kind}-images-idx3-ubyte')
with open(labels_path, 'rb') as lbpath:
magic, n = struct.unpack('>II', lbpath.read(8))
labels = np.fromfile(lbpath, dtype=np.uint8)
with open(images_path, 'rb') as imgpath:
magic, num, rows, cols = struct.unpack('>IIII', imgpath.read(16))
images = np.fromfile(imgpath, dtype=np.uint8).reshape(len(labels), 784)
return images, labels
```
其中,`path`参数表示MNIST数据集所在路径,`kind`参数表示需要加载的数据集类型,可以设置为`train`或`test`。这个函数返回两个NumPy数组,分别是图像数据和标签数据。
阅读全文