pytorch dataloader怎么加载图像数据对应的标签
时间: 2024-05-09 21:19:17 浏览: 101
在 PyTorch 中,可以通过自定义 Dataset 类和 DataLoader 类来加载图像数据和对应的标签。以下是一个示例代码:
```python
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
class MyDataset(Dataset):
def __init__(self, img_paths, labels):
self.img_paths = img_paths
self.labels = labels
def __len__(self):
return len(self.img_paths)
def __getitem__(self, idx):
img_path = self.img_paths[idx]
label = self.labels[idx]
img = Image.open(img_path).convert('RGB')
# 可以添加图像预处理操作,例如裁剪、缩放、归一化等
return img, label
img_paths = ['/path/to/image1.jpg', '/path/to/image2.jpg', ...]
labels = [0, 1, ...]
dataset = MyDataset(img_paths, labels)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
```
在上述代码中,`MyDataset` 类继承了 `torch.utils.data.Dataset` 类,并重写了 `__len__` 和 `__getitem__` 方法,分别用于返回数据集的长度和某个索引位置的数据样本。在 `__getitem__` 方法中,根据索引位置读取对应的图像路径和标签,并使用 PIL 库中的 `Image.open` 方法加载图像数据。可以根据需要添加图像预处理操作。最后,通过 `DataLoader` 类将数据集包装成批量数据加载器,可以设置批量大小和是否随机打乱数据顺序。
阅读全文