如何用pytorch自己写一个dataloader
时间: 2023-07-18 17:46:27 浏览: 53
下面是一个简单的例子,展示如何用PyTorch自己写一个dataloader。
1. 首先,需要导入PyTorch的DataLoader和Dataset模块:
```
import torch
from torch.utils.data import DataLoader, Dataset
```
2. 接下来,定义一个自定义的Dataset类,继承自PyTorch的Dataset类,并实现__len__和__getitem__函数:
```
class CustomDataset(Dataset):
def __init__(self, data):
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, index):
return self.data[index]
```
其中,__init__函数用于初始化数据集,__len__函数用于返回数据集的大小,__getitem__函数用于返回指定索引的数据。
3. 然后,定义一个函数,用于返回一个DataLoader对象:
```
def get_dataloader(data, batch_size):
dataset = CustomDataset(data)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
return dataloader
```
其中,get_dataloader函数接受两个参数:数据集data和批次大小batch_size,返回一个DataLoader对象,其中shuffle参数用于指定是否打乱数据集。
4. 最后,调用get_dataloader函数即可获取一个dataloader:
```
data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
batch_size = 2
dataloader = get_dataloader(data, batch_size)
for batch in dataloader:
print(batch)
```
这样就可以得到一个dataloader了。在本例中,数据集是一个简单的数字列表,每个批次包含两个数字,dataloader会将数据集分成多个批次,每次输出一个批次的数据。