torch.utils.data.DataLoader
时间: 2023-10-16 11:17:31 浏览: 72
torch.utils.data.DataLoader是PyTorch中的一个数据加载器,它可以方便地对数据进行批处理、打乱顺序、并行加载等操作。一般来说,我们使用DataLoader将数据处理成一个一个的batch,然后送入模型进行训练或测试。DataLoader的常用参数包括:
- dataset:数据集,一般是继承自torch.utils.data.Dataset的自定义类。
- batch_size:每个batch的大小。
- shuffle:是否打乱数据顺序。
- num_workers:用于数据加载的子进程数。
- drop_last:如果数据集的大小不能被batch_size整除,是否舍弃最后一个batch。
使用DataLoader的示例代码如下:
```python
import torch
from torch.utils.data import DataLoader, Dataset
class CustomDataset(Dataset):
def __init__(self):
self.x = torch.tensor([[1, 2], [3, 4], [5, 6], [7, 8]])
self.y = torch.tensor([0, 1, 0, 1])
def __len__(self):
return self.x.shape[0]
def __getitem__(self, idx):
return self.x[idx], self.y[idx]
dataset = CustomDataset()
dataloader = DataLoader(dataset, batch_size=2, shuffle=True, num_workers=2)
for x, y in dataloader:
print(x, y)
```
以上代码中,我们自定义了一个数据集CustomDataset,并将其传入DataLoader中。在使用DataLoader时,它会将CustomDataset中的数据按照batch_size进行划分,并在每个batch内进行打乱顺序等操作。在for循环中,我们可以直接遍历DataLoader,每次输出一个batch的数据。
阅读全文