torch.utils.data.DataLoader()是干什么的?
时间: 2023-10-16 07:24:03 浏览: 79
`torch.utils.data.DataLoader()` 是一个用于加载数据的工具,它可以将数据集封装成一个迭代器,使得我们可以利用 PyTorch 提供的多进程优化来加速数据加载,并可以批量化地获取数据。它能够自动完成数据集的 shuffle、batch、multiprocessing 等功能,并且支持自定义的数据加载方式。通常情况下,我们可以将自己的数据集封装成 `torch.utils.data.Dataset` 类,然后通过 `DataLoader` 进行数据的加载。例如:
```python
from torch.utils.data import DataLoader, Dataset
class MyDataset(Dataset):
def __init__(self, data, labels):
self.data = data
self.labels = labels
def __getitem__(self, index):
return self.data[index], self.labels[index]
def __len__(self):
return len(self.data)
dataset = MyDataset(data, labels)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
```
这里我们定义了一个 `MyDataset` 类来封装数据集,然后通过 `DataLoader` 对数据进行批量加载。其中 `batch_size` 表示每个 batch 的大小,`shuffle` 表示是否需要打乱数据集,`num_workers` 表示使用的进程数。通过这样的设置,我们可以方便地对数据进行批量化地处理和加载。
阅读全文