mnist数据集的划分代码
时间: 2023-07-09 12:48:07 浏览: 88
下面是使用Python中的sklearn库中的train_test_split函数来实现MNIST数据集划分的示例代码:
```python
from sklearn.model_selection import train_test_split
from tensorflow.keras.datasets import mnist
# 加载MNIST数据集
(X_train, y_train), (X_test, y_test) = mnist.load_data()
# 将训练集划分为训练集和验证集,比例为60%训练集,20%验证集,20%测试集
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, random_state=42)
X_test, y_test = X_test / 255.0, y_test
# 输出划分后的数据集大小
print('X_train shape:', X_train.shape)
print('y_train shape:', y_train.shape)
print('X_val shape:', X_val.shape)
print('y_val shape:', y_val.shape)
print('X_test shape:', X_test.shape)
print('y_test shape:', y_test.shape)
```
这里使用了sklearn库中的train_test_split函数将训练集划分为训练集和验证集,同时保留20%的测试集。我们使用了mnist.load_data()函数来加载MNIST数据集,然后将数据集划分为训练集、验证集和测试集,并将输入数据归一化处理,最后输出数据集的大小。
阅读全文