PyTorch读取读取Cifar数据集并显示图片的实例讲解数据集并显示图片的实例讲解
今天小编就为大家分享一篇PyTorch读取Cifar数据集并显示图片的实例讲解,具有很好的参考价值,希望对大家
有所帮助。一起跟随小编过来看看吧
首先了解一下需要的几个类所在的首先了解一下需要的几个类所在的package
from torchvision import transforms, datasets as ds
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
#transform = transforms.Compose是把一系列图片操作组合起来,比如减去像素均值等。
#DataLoader读入的数据类型是PIL.Image
#这里对图片不做任何处理,仅仅是把PIL.Image转换为torch.FloatTensor,从而可以被pytorch计算
transform = transforms.Compose(
[
transforms.ToTensor()
]
)
Step 1,得到,得到torch.utils.data.Dataset实例。实例。
torch.utils.data.Dataset是一个抽象类,CIFAR100是它的一个实例化子类
train=True,读取训练集;train=False,读取测试集
download=False,不下载。如果为True,则先检查root下有无该数据集,如果没有就先下载。
train_set = ds.CIFAR100(root='.', train=True, transform=transform, target_transform=None, download=True)
Step 2,把,把Dataset封装成封装成torch.utils.data.DataLoader
data_loader = DataLoader(dataset=train_set,
batch_size=1,
shuffle=False,
num_workers=2)
# # 生成torch.utils.data.DataLoaderIter
# # 不过DataLoaderIter它会被DataLoader自动创建并且调用,我们用不到
# data_iter = iter(data_loader)
# images, labels = next(data_iter)
step 3,从,从DataLoader里读取数据,并将图片显示出来。里读取数据,并将图片显示出来。