Pytorch加载自己的数据集(使用DataLoader读取Dataset)
时间: 2023-12-10 16:06:01 浏览: 107
好的,我来给你讲一下如何用 PyTorch 加载自己的数据集。
首先,你需要创建一个自定义的 Dataset 类,它必须包含两个方法:\_\_len\_\_ 和 \_\_getitem\_\_。
```python
import torch.utils.data as data
class MyDataset(data.Dataset):
def __init__(self, data_path):
# 初始化数据集
self.data = []
with open(data_path, 'r') as f:
for line in f:
self.data.append(line.strip())
def __len__(self):
# 返回数据集的长度
return len(self.data)
def __getitem__(self, index):
# 根据索引返回一条数据
return self.data[index]
```
在上面的代码中,我们首先导入了 PyTorch 的 data 模块,然后定义了一个 MyDataset 类。这个类的构造函数需要传入数据集的路径,然后读取数据集并进行初始化。在 \_\_len\_\_ 方法中,我们返回了数据集的长度,这个方法会被 DataLoader 调用以确定数据集的大小。在 \_\_getitem\_\_ 方法中,我们根据索引返回一条数据,这个方法会被 DataLoader 调用以获取数据。
接下来,我们需要创建一个 DataLoader 对象来读取数据集。DataLoader 会按照一定的 batch_size 对数据集进行分批,并且提供数据的迭代器。
```python
from torch.utils.data import DataLoader
# 创建数据集
dataset = MyDataset('data.txt')
# 创建 DataLoader
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
```
在上面的代码中,我们首先创建了一个 MyDataset 对象,然后使用 DataLoader 将这个数据集分批。batch_size 参数指定了每个 batch 中包含的样本数,shuffle 参数指定了是否打乱数据集,num_workers 参数指定了使用多少个子进程来读取数据集。
现在,我们可以通过迭代 DataLoader 来读取数据了。
```python
for batch in dataloader:
# 处理数据
pass
```
在上面的代码中,我们通过迭代 dataloader 来读取数据集,每个 batch 的数据会被封装成一个 Tensor 对象,我们可以直接对这个 Tensor 进行操作。
希望这个回答能够帮助到你!
阅读全文