写一个将fruit360 PyTorch 图像加载到 DataLoader,并探索数据集,打印一个示例图像及其类和标签
时间: 2024-05-05 14:17:15 浏览: 103
以下是一个将fruit360 PyTorch 图像加载到 DataLoader 的示例代码:
```python
import torch
from torchvision import transforms, datasets
# 定义数据集路径
data_dir = "path/to/fruit360"
# 定义数据变换
data_transforms = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor()
])
# 加载数据集
image_datasets = datasets.ImageFolder(data_dir, transform=data_transforms)
# 定义数据加载器
data_loader = torch.utils.data.DataLoader(image_datasets, batch_size=32, shuffle=True)
```
要打印一个示例图像及其类和标签,可以使用以下代码:
```python
import matplotlib.pyplot as plt
import numpy as np
# 获取一个批次的数据
images, labels = next(iter(data_loader))
# 获取第一张图像及其类和标签
image = images[0]
label = labels[0]
class_name = image_datasets.classes[label]
# 将图像转换为 NumPy 数组以便可视化
image = image.numpy().transpose((1, 2, 0))
# 显示图像及其类和标签
plt.imshow(image)
plt.title(class_name)
plt.show()
```
阅读全文