pytorch加载图片数据集
时间: 2023-09-08 16:15:19 浏览: 107
在 PyTorch 中,可以使用 `torchvision` 库来加载常用的图像数据集,同时也提供了数据的预处理方法。
下面是一个加载 CIFAR-10 数据集的示例代码:
```python
import torch
import torchvision
import torchvision.transforms as transforms
# 定义数据预处理方式
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
# 加载训练集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)
# 加载测试集
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
shuffle=False, num_workers=2)
# 类别标签
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
```
在代码中, `transforms` 模块提供了一系列对数据进行预处理的方法,这里使用了 `ToTensor()` 将图像转换为张量,并使用 `Normalize()` 对数据进行归一化处理。
使用 `torchvision.datasets` 模块可以加载常用的数据集,如 CIFAR-10,ImageNet 等。 `DataLoader` 则提供了对数据进行批处理的功能,可以通过设置 `batch_size` 来指定每个批次的大小。
最后,我们可以通过 `classes` 变量来获取不同类别的标签信息。
阅读全文