Data.DataLoader
时间: 2023-10-20 10:32:11 浏览: 84
Data.DataLoader是PyTorch中用于数据批处理的工具。它可以将数据包装成小批量进行训练。在使用DataLoader时,我们需要提供一个数据集(dataset)和一个批大小(batch_size)来指定每个批次的样本数量。此外,我们还可以选择是否对数据进行随机洗牌(shuffle)以及是否在数据样本数量不能整除批大小时舍弃最后一批数据(drop_last)。
下面是一个使用DataLoader进行数据批处理的示例代码:
```
import torch
import torch.utils.data as Data
BATCH_SIZE = 5
x = torch.linspace(1, 11, 11)
y = torch.linspace(11, 1, 11)
torch_dataset = Data.TensorDataset(x, y)
loader = Data.DataLoader(
dataset=torch_dataset,
batch_size=BATCH_SIZE,
shuffle=True,
drop_last=True,
)
def show_batch():
for epoch in range(3):
for step, (batch_x, batch_y) in enumerate(loader):
# 在这里进行训练操作
print("step:{}, batch_x:{}, batch_y:{}".format(step, batch_x, batch_y))
if __name__ == '__main__':
show_batch()
```
在上述代码中,我们首先定义了一个大小为5的批次(BATCH_SIZE),然后创建了一个TensorDataset对象,将x和y数据合并为一个数据集。接下来,我们使用DataLoader将数据集包装成一个迭代器,每次迭代产生一个批次的数据。在训练过程中,我们可以通过遍历迭代器来获取每个批次的数据,并进行相应的训练操作。<span class="em">1</span><span class="em">2</span><span class="em">3</span>
#### 引用[.reference_title]
- *1* *2* [PyTorch之torch.utils.data.DataLoader详解](https://blog.csdn.net/qq_36044523/article/details/118914223)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 50%"]
- *3* [python torch.utils.data.DataLoader使用方法](https://download.csdn.net/download/weixin_38605590/12853596)[target="_blank" data-report-click={"spm":"1018.2226.3001.9630","extra":{"utm_source":"vip_chatgpt_common_search_pc_result","utm_medium":"distribute.pc_search_result.none-task-cask-2~all~insert_cask~default-1-null.142^v93^chatsearchT3_2"}}] [.reference_item style="max-width: 50%"]
[ .reference_list ]
阅读全文