使用torchvision 加载并预处理CIFAR-10 数据集。
时间: 2024-05-08 22:17:47 浏览: 106
以下是使用torchvision加载并预处理CIFAR-10数据集的示例代码:
```python
import torchvision
import torchvision.transforms as transforms
# 定义预处理转换
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
# 加载训练集和测试集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
shuffle=False, num_workers=2)
# 定义类别标签
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
```
这里,我们定义了一个名为“transform”的预处理转换,它将图像转换为PyTorch张量,并对其进行归一化。然后,我们使用`torchvision.datasets.CIFAR10`加载训练集和测试集,并将其传递给`torch.utils.data.DataLoader`,以便我们可以对其进行迭代。最后,我们定义了CIFAR-10数据集的类别标签。
阅读全文