python中DataLoader
时间: 2024-04-12 09:08:45 浏览: 100
DataLoader是一个用于批量加载数据的工具,可以对数据进行分批处理、shuffle(随机打乱)等操作。在PyTorch中,DataLoader通常与Dataset一起使用,Dataset用于读取数据,DataLoader用于将数据分批并进行处理。
下面是一个使用DataLoader的例子:
```python
import torch
from torch.utils.data import DataLoader, Dataset
# 自定义Dataset类,实现__len__和__getitem__方法
class MyDataset(Dataset):
def __init__(self, data):
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, index):
return self.data[index]
# 定义数据
data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
# 创建Dataset实例
dataset = MyDataset(data)
# 创建DataLoader实例,batch_size=2,shuffle=True
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
# 迭代DataLoader,每次迭代得到一个batch的数据
for batch in dataloader:
print(batch)
```
运行结果:
```python
tensor([10, 7])
tensor([ 2, 1])
tensor([9, 6])
tensor([5, 8])
tensor([4, 3])
```
可以看到,每次迭代得到一个batch_size=2的数据,且每个batch数据顺序是随机的(shuffle=True)。
阅读全文