如何使用torchvision加载数据集?
时间: 2024-05-02 11:13:50 浏览: 12
使用torchvision加载数据集可以分为以下几个步骤:
1. 导入必要的库:
```python
import torchvision
import torchvision.transforms as transforms
```
2. 定义数据集的转换操作:
```python
transform = transforms.Compose([
transforms.ToTensor(), # 将图像转换为Tensor
transforms.Normalize((0.5,), (0.5,)) # 标准化图像
])
```
3. 加载训练集和测试集:
```python
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
testset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
```
其中,`root`参数指定数据集的存储路径,`train`参数指定是否加载训练集,`download`参数指定是否自动下载数据集,`transform`参数指定对数据集进行的转换操作。
4. 创建数据加载器:
```python
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32, shuffle=True)
testloader = torch.utils.data.DataLoader(testset, batch_size=32, shuffle=False)
```
其中,`batch_size`参数指定每个批次的样本数量,`shuffle`参数指定是否在每个epoch之前对数据进行洗牌。
现在你可以使用`trainloader`和`testloader`来迭代访问训练集和测试集中的样本了。