torch.utils.data.DataLoader函数
时间: 2023-12-11 18:21:50 浏览: 135
torch.utils.data.DataLoader函数是PyTorch中用于加载数据的工具函数之一。它提供了一个简单而高效的数据加载器,用于在训练过程中对数据进行批处理、打乱和并行加载。
DataLoader函数的主要参数包括:
- dataset:表示要加载的数据集,可以是自定义的Dataset类或者已存在的预定义数据集(如torchvision.datasets中的数据集)。
- batch_size:表示每个批次中的样本数量。
- shuffle:表示是否对数据进行打乱操作,以便每个epoch都能得到不同的样本顺序。
- num_workers:表示用于数据加载的子进程数量,可以加速数据加载过程。
- collate_fn:表示用于将样本列表转换为小批量张量的函数,默认使用torch.utils.data.dataloader.default_collate。
- pin_memory:表示是否将数据保存在CUDA固定内存中,可以加速GPU上的数据传输。
使用DataLoader函数可以方便地将数据集加载到模型中进行训练或推断。例如,可以通过以下方式创建一个数据加载器:
```python
from torch.utils.data import DataLoader
# 创建自定义数据集对象
dataset = MyDataset()
# 创建数据加载器
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
```
然后,可以使用for循环遍历数据加载器,并逐个获取每个批次的数据进行训练或推断:
```python
for batch_data in dataloader:
inputs, labels = batch_data
# 在这里进行模型训练或推断
```
通过使用DataLoader函数,可以更加便捷地对大规模数据集进行高效的批处理和并行加载。
阅读全文