写一个将fruit360 图像加载到 DataLoader,并探索数据集,打印一个示例图像及其类和标签
时间: 2024-05-24 20:10:02 浏览: 91
Python-Fruits360包含水果和蔬菜的图像数据集
4星 · 用户满意度95%
import torch
from torchvision import transforms, datasets
# 定义数据变换
data_transform = transforms.Compose([
transforms.Resize(224),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
# 加载数据集
fruit_dataset = datasets.ImageFolder(root='fruit360', transform=data_transform)
# 创建 DataLoader
batch_size = 32
fruit_dataloader = torch.utils.data.DataLoader(fruit_dataset, batch_size=batch_size, shuffle=True)
# 探索数据集
print('数据集大小:', len(fruit_dataset))
print('类别数:', len(fruit_dataset.classes))
print('类别名称:', fruit_dataset.classes)
# 打印一个示例图像及其类和标签
dataiter = iter(fruit_dataloader)
images, labels = dataiter.next()
print('图像大小:', images.size())
print('标签大小:', labels.size())
print('标签:', labels)
import matplotlib.pyplot as plt
import numpy as np
def imshow(img):
img = img / 2 + 0.5 # 将图像还原为原始范围
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
# 显示一个 batch 的图像
imshow(torchvision.utils.make_grid(images))
# 显示标签
print(' '.join('%5s' % fruit_dataset.classes[labels[j]] for j in range(batch_size)))
阅读全文