pytorch加载数据集的代码
时间: 2024-04-12 16:25:00 浏览: 207
当使用PyTorch加载数据集时,通常会使用`torchvision.datasets`模块中的类来加载常见的数据集,例如MNIST、CIFAR等。以下是一个加载MNIST数据集的示例代码:
```python
import torch
from torchvision import datasets, transforms
# 定义数据预处理的转换操作
transform = transforms.Compose([
transforms.ToTensor(), # 将图像转换为Tensor
transforms.Normalize((0.5,), (0.5,)) # 标准化图像数据
])
# 加载训练集和测试集
train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform, download=True)
# 创建数据加载器
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)
```
上述代码首先导入了必要的库,然后定义了数据预处理的转换操作,包括将图像转换为Tensor和标准化图像数据。接下来使用`datasets.MNIST`类加载MNIST数据集,指定了数据集的存储路径、是否为训练集、预处理转换操作和是否下载数据集。最后,通过`torch.utils.data.DataLoader`创建了训练集和测试集的数据加载器,可以指定批量大小和是否打乱数据。
阅读全文