from torch.io import DataLoader, Dataset
时间: 2024-02-15 08:05:34 浏览: 156
在 PyTorch 中,`DataLoader` 和 `Dataset` 类不属于 `torch.io` 模块,而是属于 `torch.utils.data` 模块。`Dataset` 表示一个数据集,可以通过索引访问其中的元素,`DataLoader` 则用于对数据集进行批量读取和处理。它们的基本用法如下:
```python
import torch.utils.data as data
# 定义数据集类
class MyDataset(data.Dataset):
def __init__(self, data_list):
self.data = data_list
def __len__(self):
return len(self.data)
def __getitem__(self, index):
return self.data[index]
# 创建数据集对象
dataset = MyDataset([1, 2, 3, 4, 5])
# 创建数据加载器对象
dataloader = data.DataLoader(dataset, batch_size=2, shuffle=True)
# 使用数据加载器进行迭代
for batch in dataloader:
print(batch)
```
在这个例子中,我们首先定义了一个 `MyDataset` 类,表示一个简单的数据集。在 `__getitem__` 方法中,我们实现了对数据集元素的访问,这里仅仅是返回了数据集中的整数。然后,我们创建了一个 `MyDataset` 对象,并将其传入 `DataLoader` 中。我们指定了 `batch_size=2` 表示每次读取两个数据元素,`shuffle=True` 表示在每个 epoch 开始时对数据集进行随机排序。最后,我们使用 `DataLoader` 对象进行迭代,每次迭代会得到一个大小为 2 的数据批量。
需要注意的是,`Dataset` 类必须实现 `__getitem__` 和 `__len__` 两个方法,其中 `__getitem__` 方法用于返回指定索引的数据元素,`__len__` 方法用于返回数据集的长度。`DataLoader` 在读取数据时会调用 `__getitem__` 方法来获取数据元素,然后通过 `collate_fn` 参数来对数据进行批量处理,最终返回一个批量的数据张量。
阅读全文