python load_mnist函数代码
时间: 2024-05-04 07:16:25 浏览: 110
以下是一个基于TensorFlow的Python代码示例,可以加载MNIST数据集。
```
import tensorflow as tf
def load_mnist():
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.reshape(-1, 28, 28, 1).astype('float32') / 255.0
x_test = x_test.reshape(-1, 28, 28, 1).astype('float32') / 255.0
y_train = tf.keras.utils.to_categorical(y_train, 10)
y_test = tf.keras.utils.to_categorical(y_test, 10)
return (x_train, y_train), (x_test, y_test)
```
该函数首先使用Keras库中的`mnist.load_data()`函数从TensorFlow包中下载MNIST数据集。然后,它将训练和测试数据分别存储在`x_train`、`y_train`和`x_test`、`y_test`变量中。
接着,该函数将输入数据的维度从`(num_samples, 28, 28)`变为`(num_samples, 28, 28, 1)`,其中`num_samples`是训练或测试样本的数量。这是因为TensorFlow要求输入数据的形状为`(batch_size, height, width, channels)`。
最后,函数将类别标签转换为独热编码向量,并返回`((x_train, y_train), (x_test, y_test))`元组。
阅读全文