mnist数据集处理的代码
时间: 2023-10-08 07:13:42 浏览: 42
以下是一个简单的Python代码,用于加载和处理MNIST数据集:
```python
import numpy as np
import mnist # pip install mnist
# 加载MNIST数据集
train_images = mnist.train_images() # 训练集图像
train_labels = mnist.train_labels() # 训练集标签
test_images = mnist.test_images() # 测试集图像
test_labels = mnist.test_labels() # 测试集标签
# 数据预处理
# 将图像像素值标准化为0到1之间的浮点数
train_images = (train_images / 255) - 0.5
test_images = (test_images / 255) - 0.5
# 将图像展平为向量
train_images = train_images.reshape((-1, 784))
test_images = test_images.reshape((-1, 784))
# 打印数据集维度
print(train_images.shape) # (60000, 784)
print(train_labels.shape) # (60000,)
print(test_images.shape) # (10000, 784)
print(test_labels.shape) # (10000,)
```
此代码使用MNIST Python包加载数据集,并对数据进行了一些预处理,如将像素值标准化为0到1之间的浮点数,并将图像展平为向量。