torch.utils.data.DataLoader
时间: 2023-07-23 07:13:35 浏览: 90
`torch.utils.data.DataLoader` 是 PyTorch 中用于加载数据的工具类,它可以将数据集封装成一个迭代器,方便我们进行批量读取数据和多进程加载数据等操作。
使用 `DataLoader` 类需要传入三个参数:数据集(`dataset`)、批大小(`batch_size`)和是否启用多进程加载数据的标志(`num_workers`)。
`DataLoader` 类的一些常用参数:
- `dataset`:数据集对象。必须实现 `__getitem__()` 和 `__len__()` 方法,可以使用 PyTorch 中提供的 `torch.utils.data.Dataset` 类或自己实现数据集类。
- `batch_size`:每个批次加载的样本数量。
- `shuffle`:是否对数据进行随机洗牌。
- `num_workers`:使用多少个子进程来加载数据。默认值为 `0`,表示不使用多进程。
- `collate_fn`:指定如何将样本列表转换为批次张量。默认情况下,它会将样本列表转换为张量,并将其堆叠在一起。如果数据集中的每个样本大小不一致,则需要自定义此函数。
例如,以下代码使用 `DataLoader` 类加载一个名为 `MyDataset` 的数据集,并设置批大小为 64,使用 4 个子进程来加载数据:
```python
from torch.utils.data import DataLoader
from my_dataset import MyDataset
dataset = MyDataset()
dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4)
```
`dataloader` 对象是一个迭代器,可以使用 `for` 循环遍历数据集中的每个批次。每个批次的数据是一个长度为 `batch_size` 的张量,即一个形状为 `(batch_size, ...)` 的多维数组。需要注意的是,最后一个批次的大小可能会小于 `batch_size`,需要特别处理。
阅读全文