请问怎么加载MNIST⼿写数字集
时间: 2024-05-02 15:20:42 浏览: 7
您好,加载MNIST手写数字集可以使用Python中的TensorFlow或者PyTorch库来实现。以下是TensorFlow库的示例代码:
```python
import tensorflow as tf
from tensorflow.keras.datasets import mnist
# 加载数据集
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# 数据预处理
x_train, x_test = x_train / 255.0, x_test / 255.0
# 输出数据集的形状
print("训练集的形状:", x_train.shape, y_train.shape)
print("测试集的形状:", x_test.shape, y_test.shape)
```
上述代码首先使用TensorFlow库中的`mnist`模块加载MNIST数据集,然后对数据进行预处理,最后输出训练集和测试集的形状。
当然,如果您使用的是PyTorch库,也可以使用以下代码加载MNIST数据集:
```python
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
# 定义数据预处理
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
# 加载数据集
train_set = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_set = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
# 定义数据加载器
train_loader = torch.utils.data.DataLoader(train_set, batch_size=64)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=64)
# 输出数据集的形状
print("训练集的形状:", train_set.data.shape, train_set.targets.shape)
print("测试集的形状:", test_set.data.shape, test_set.targets.shape)
```
在上述代码中,我们使用了PyTorch库中的`datasets`和`transforms`模块来加载和预处理MNIST数据集,最后输出训练集和测试集的形状。其中,数据加载器可以帮助我们将数据分成多个batch进行处理,并且可以随机打乱数据的顺序,避免过拟合。