编写dataloader类
时间: 2024-06-13 14:06:55 浏览: 92
为了编写一个dataloader类,我们需要先定义一个继承自Dataset类的数据集类,然后再使用这个数据集类来创建dataloader对象。下面是一个简单的例子:
```python
import torch
from torch.utils.data import Dataset, DataLoader
# 定义数据集类
class MyDataset(Dataset):
def __init__(self, data):
self.data = data
def __getitem__(self, index):
return self.data[index]
def __len__(self):
return len(self.data)
# 创建数据集对象
my_data = [1, 2, 3, 4, 5]
dataset = MyDataset(my_data)
# 创建dataloader对象
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
# 使用dataloader迭代数据
for batch in dataloader:
print(batch)
```
在上面的例子中,我们首先定义了一个数据集类`MyDataset`,它继承自`Dataset`类,并实现了`__getitem__`和`__len__`方法。`__getitem__`方法用于获取指定index的数据,`__len__`方法用于返回数据集的长度。
然后我们创建了一个数据集对象`dataset`,并将其传递给`DataLoader`类的构造函数中,同时指定了`batch_size`和`shuffle`参数。最后我们使用`for`循环和`dataloader`对象迭代数据,每次迭代返回一个batch的数据。
阅读全文