pytorch读取部分cifar10
时间: 2023-11-02 10:06:02 浏览: 119
引用中提到,PyTorch中读取CIFAR数据集的常用方法是使用torch.utils.data.DataLoader接口。该接口可以将数据按照指定的batch size封装成Tensor,并作为模型的输入。下面是一个示例代码,用于读取部分CIFAR10数据集:
```python
import torch
import torchvision
import torchvision.transforms as transforms
# 定义数据预处理的方法
transform = transforms.Compose([
transforms.ToTensor(), # 将图像转换为Tensor
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 对图像进行标准化处理
])
# 载入CIFAR10训练集
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)
# 读取部分CIFAR10数据集
for i, data in enumerate(trainloader, 0):
inputs, labels = data
if i < 5: # 只读取前5个batch的数据
print(f"Batch {i+1}:")
print("Inputs:", inputs)
print("Labels:", labels)
print()
else:
break
```
阅读全文