torch.utils.data.Dataset类的__getitem__函数
时间: 2024-10-14 20:18:43 浏览: 29
`torch.utils.data.Dataset` 类的 `__getitem__` 函数是其核心组件之一,它定义了从数据集中获取单个样本的方式。这个函数的主要作用是:
1. 根据传入的索引(通常是整数),查找并返回对应的数据项。这对于支持随机访问的数据集至关重要,如文件夹中的图片、CSV 文件中的表格行等[^1]。
2. 对于复杂的任务,如图像分类,可以在这个函数内部执行更复杂的操作。例如,如果数据集是图像和标签对,`__getitem__` 可以读取指定索引的图像文件,加载图像到 PyTorch 的张量中,并同时返回相应的标签。示例代码可能如下所示:
```python
class ImageDataset(Dataset):
def __init__(self, image_folder, labels_file):
self.images = os.listdir(image_folder)
self.labels = load_labels(labels_file)
def __getitem__(self, index):
img_path = os.path.join(self.image_folder, self.images[index])
img = Image.open(img_path).convert('RGB') # 加载图像并转换为RGB
label = self.labels[index] # 获取对应标签
return torch.tensor(img), label # 返回PyTorch张量表示的图像和标签
def __len__(self):
return len(self.images)
```
在这里,`__getitem__` 负责读取图片文件,将其转换成适合神经网络输入的张量形式,同时提取对应的标签。
阅读全文