使用pytorch编写编写load_data函数用CIFAR10数据
时间: 2024-05-06 14:18:40 浏览: 164
好的,以下是一个使用PyTorch加载CIFAR10数据集的示例代码:
```python
import torch
import torchvision
import torchvision.transforms as transforms
def load_data(batch_size=32):
transform = transforms.Compose(
[transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
shuffle=False, num_workers=2)
classes = ('plane', 'car', 'bird', 'cat',
'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
return trainloader, testloader, classes
```
在这个函数中,我们使用了PyTorch内置的 torchvision.datasets.CIFAR10 数据集类来加载CIFAR10数据集。我们对数据进行了归一化和转换为张量的操作,然后将训练集和测试集分别加载到 DataLoader 中。最后,我们还定义了 CIFAR10 数据集中类别的名称。
可以使用以下代码来调用该函数:
```python
trainloader, testloader, classes = load_data(batch_size=32)
print('Number of training images: {}'.format(len(trainloader.dataset)))
print('Number of test images: {}'.format(len(testloader.dataset)))
print('Number of classes: {}'.format(len(classes)))
```
这段代码将打印出训练集、测试集和类别的数量。
阅读全文