帮我写一个DataLoader类
时间: 2024-05-05 17:16:07 浏览: 107
以下是一个基本的DataLoader类的示例代码:
```python
import torch
class CustomDataset(torch.utils.data.Dataset):
def __init__(self, data):
self.data = data
def __getitem__(self, index):
return self.data[index]
def __len__(self):
return len(self.data)
class CustomDataLoader(torch.utils.data.DataLoader):
def __init__(self, data, batch_size=1, shuffle=False, num_workers=0):
dataset = CustomDataset(data)
super().__init__(dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
```
使用方法:
```python
data = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
data_loader = CustomDataLoader(data, batch_size=3, shuffle=True)
for batch in data_loader:
print(batch)
```
输出:
```
tensor([5, 3, 2])
tensor([ 8, 1, 10])
tensor([9, 4, 6])
tensor([7])
```
这个DataLoader类使用了一个自定义的Dataset类`CustomDataset`,它只是一个简单的封装,返回数据的index对应的数据。`CustomDataLoader`继承了`torch.utils.data.DataLoader`类,并重写了初始化函数。在初始化函数中,它首先创建一个`CustomDataset`对象,然后使用`super()`函数调用父类的初始化函数。在后面的代码中,我们可以像使用`torch.utils.data.DataLoader`类一样使用`CustomDataLoader`类。
阅读全文