pytorch datasets用法
时间: 2023-11-22 10:43:50 浏览: 89
PyTorch中的datasets模块提供了许多常用数据集的接口,可以方便地加载和处理数据。使用datasets模块,我们可以轻松地将数据集加载到PyTorch中,并进行预处理和转换。
下面是一个使用datasets模块加载CIFAR-10数据集的示例代码:
```python
import torch
import torchvision
import torchvision.transforms as transforms
# 定义数据预处理方式
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
# 加载训练集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)
# 加载测试集
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
shuffle=False, num_workers=2)
# 定义类别标签
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
```
在上面的代码中,我们首先定义了数据预处理方式,然后使用datasets模块加载了CIFAR-10数据集,并将其分为训练集和测试集。最后,我们定义了类别标签。
阅读全文