mnist数据集加载
时间: 2024-12-27 11:28:33 浏览: 11
### 如何在 Python 中加载 MNIST 数据集
#### 使用 PyTorch 加载 MNIST 数据集
为了使用 PyTorch 进行训练,`train_data` 和 `test_data` 的类型应转换为 torch 类型的数据结构。具体来说,在 PyTorch 中可以利用内置的 torchvision 库来轻松下载和预处理 MNIST 数据集。
```python
import torch
from torchvision import datasets, transforms
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
# 下载并加载训练数据
trainset = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
# 下载并加载测试数据
testset = datasets.MNIST('~/.pytorch/MNIST_data/', download=True, train=False, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=64, shuffle=True)
```
这段代码展示了如何通过 torchvision 来获取 MNIST 数据集,并将其转化为适合输入神经网络的形式[^1]。
#### 使用 TensorFlow 加载 MNIST 数据集
对于 TensorFlow 用户而言,同样可以通过 Keras API 轻松访问 MNIST 数据集:
```python
import tensorflow as tf
from tensorflow.keras.datasets import mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# 对图像数据进行标准化操作
x_train, x_test = x_train / 255.0, x_test / 255.0
```
此部分代码说明了怎样借助 TensorFlow 提供的功能迅速取得 MNIST 图像及其标签,并执行必要的预处理工作[^3]。
#### 使用其他库加载 MNIST 数据集
除了上述两种主流框架外,还可以采用更底层的方式读取原始 CSV 文件形式存储的 MNIST 数据集。这种方法适用于那些希望绕过高级API直接接触基础数据的情况。
```python
import pandas as pd
df_train = pd.read_csv('mnist_data/mnist_train.csv', header=None)
df_test = pd.read_csv('mnist_data/mnist_test.csv', header=None)
X_train = df_train.iloc[:, 1:].values.astype('float32') / 255.
y_train = df_train.iloc[:, 0].values.astype('int')
X_test = df_test.iloc[:, 1:].values.astype('float32') / 255.
y_test = df_test.iloc[:, 0].values.astype('int')
```
这里给出了当 MNIST 数据被保存成 CSV 文件时的一种导入方法[^2]。
阅读全文