torch.utils.data.DataLoader
时间: 2023-12-11 17:05:13 浏览: 83
`torch.utils.data.DataLoader` 是 PyTorch 中用于加载数据的实用工具,它能够方便地对数据进行批量处理和并行化操作。通过 `DataLoader`,你可以将自定义的数据集或已有的数据集对象转换为可迭代的批量数据加载器。
`DataLoader` 提供了以下几个主要功能:
1. 数据加载与处理:可以使用自定义的数据集类或者已有的数据集对象作为输入,在数据加载过程中可以进行各种预处理操作,如数据变换、归一化等。
2. 批量加载:可以指定每个批次(batch)的大小,`DataLoader` 会自动将数据分成多个批次进行加载。
3. 批处理并行化:可以通过设置 `num_workers` 参数,将批处理操作并行化处理,提高数据加载的效率。
4. 数据打乱与重复:可以通过设置 `shuffle` 参数来打乱数据的顺序,增加数据集的随机性。
5. 数据拆分:可以将大型数据集拆分为训练集、验证集和测试集等。
下面是一个简单的示例代码,展示了如何使用 `DataLoader` 加载数据集:
```python
import torch
from torch.utils.data import DataLoader, Dataset
# 自定义数据集类
class MyDataset(Dataset):
def __init__(self, data):
self.data = data
def __getitem__(self, index):
return self.data[index]
def __len__(self):
return len(self.data)
# 创建数据集对象
data = [1, 2, 3, 4, 5]
dataset = MyDataset(data)
# 创建数据加载器
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=2)
# 遍历数据加载器
for batch in dataloader:
print(batch)
```
通过以上代码,我们创建了一个自定义的数据集类 `MyDataset`,然后将数据集对象传入 `DataLoader` 中进行批量加载。在遍历数据加载器时,每次迭代会返回一个批次的数据。
阅读全文