加载mnist数据集代码
时间: 2024-10-14 09:00:23 浏览: 19
tensorflow实现加载mnist数据集
加载MNIST手写数字数据集通常是在Python中使用深度学习库如TensorFlow或PyTorch。这里以Keras库为例,因为它是两个库的高级API:
```python
from keras.datasets import mnist
import numpy as np
# 加载数据
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
# 数据预处理
# 将像素值归一化到0-1之间
train_images = train_images.astype('float32') / 255
test_images = test_images.astype('float32') / 255
# 将标签转换为one-hot编码
train_labels = to_categorical(train_labels)
test_labels = to_categorical(test_labels)
# 获取输入维度(28x28图像)
input_shape = (28, 28, 1) if len(train_images.shape) == 3 else (28, 28)
# 如果需要,可以进一步调整数据形状、大小等
```
在这个例子中,`load_data()`函数会返回训练集和测试集的图像数组以及对应的标签。然后对数据进行必要的预处理,如归一化和one-hot编码。
阅读全文