pytorch下载MNIST数据集
时间: 2024-09-26 15:00:36 浏览: 43
在PyTorch中下载MNIST手写数字数据集通常涉及使用`torchvision`模块,这是一个方便的数据加载工具包,它包含了各种常用的图像数据集,包括MNIST。以下是下载并预处理MNIST数据集的步骤:
1. **安装 torchvision**:
首先,确保已经安装了PyTorch。如果没有,可以使用pip安装:
```bash
pip install torch torchvision
```
2. **导入所需库并下载数据**:
导入`torchvision.datasets`和`torch.utils.data.DataLoader`等必要的模块,并加载MNIST数据集:
```python
import torch
from torchvision import datasets, transforms
# 数据转换
transform = transforms.Compose([
transforms.ToTensor(), # 将像素值转化为0-1范围
transforms.Normalize((0.5,), (0.5,)) # 归一化到均值=0.5,标准差=0.5
])
# 加载训练集和测试集
train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
```
这里 `root='./data'` 表示数据将保存在当前目录下的 'data' 文件夹中。如果文件夹不存在,会自动下载。
3. **创建数据加载器**:
使用`DataLoader`函数创建迭代器,以便在训练过程中批处理数据:
```python
batch_size = 64
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
```
现在,你可以使用`train_loader`和`test_loader`来访问训练集和测试集的样本及其对应的标签。
阅读全文