pytorch 自制的数据集怎么载入
时间: 2023-11-13 09:09:41 浏览: 67
pytorch 自定义数据集加载方法
5星 · 资源好评率100%
载入自制数据集的方法可以按照以下步骤进行:
1. 定义数据集类:自制数据集需要继承 `torch.utils.data.Dataset` 类,并实现 `__len__` 和 `__getitem__` 方法。
2. 加载数据集:使用 `torch.utils.data.DataLoader` 类可以加载数据集,设置 batch_size、shuffle 等参数。
下面是一个简单的例子:
```python
import torch
from torch.utils.data import Dataset, DataLoader
class MyDataset(Dataset):
def __init__(self, data):
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, index):
return self.data[index]
# 创建数据集
data = [1, 2, 3, 4, 5]
my_dataset = MyDataset(data)
# 加载数据集
dataloader = DataLoader(my_dataset, batch_size=2, shuffle=True)
# 迭代数据集
for batch in dataloader:
print(batch)
```
在这个例子中,我们创建了一个包含数字 1~5 的数据集 `MyDataset`,使用 `DataLoader` 加载数据集,并设置 batch_size=2、shuffle=True。最后,我们通过迭代 dataloader 来获取数据集的 batch,并打印出来。
阅读全文