pytorch 封装数据集
时间: 2023-06-30 13:03:10 浏览: 179
在 PyTorch 中,可以通过封装数据集的方式来方便地进行数据加载和预处理。下面是一个简单的 PyTorch 数据集封装的示例代码:
``` python
import torch
from torch.utils.data import Dataset
class MyDataset(Dataset):
def __init__(self, data):
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, index):
x = self.data[index][0]
y = self.data[index][1]
return torch.tensor(x), torch.tensor(y)
```
在这个例子中,我们定义了一个名为 MyDataset 的数据集类,该类继承了 PyTorch 的 Dataset 类。在 MyDataset 类的构造函数中,我们传入了一个数据集 data;该数据集可以是一个列表、元组、Numpy 数组或 Pandas 数据帧等数据类型。
在 MyDataset 类中,我们还实现了两个方法:__len__ 和 __getitem__。其中,__len__ 用于返回数据集的长度,而 __getitem__ 用于获取数据集中指定索引位置的数据。在这个例子中,我们假设数据集 data 中每个样本是一个二元组,第一个元素表示输入数据,第二个元素表示标签。在 __getitem__ 方法中,我们将输入数据和标签分别返回为 PyTorch 的 tensor 类型。
有了 MyDataset 类之后,我们可以使用 PyTorch 中的 DataLoader 类来进行数据加载和批处理。下面是一个使用 MyDataset 类和 DataLoader 类的示例代码:
``` python
from torch.utils.data import DataLoader
# 假设我们有一个数据集 data
data = [(1, 2), (3, 4), (5, 6), (7, 8)]
# 创建 MyDataset 对象
dataset = MyDataset(data)
# 创建 DataLoader 对象
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
# 迭代 DataLoader 并输出数据
for batch_x, batch_y in dataloader:
print(batch_x, batch_y)
```
在这个例子中,我们首先创建了一个数据集 data,然后使用 MyDataset 类将其封装成一个 PyTorch 数据集对象 dataset。接着,我们使用 DataLoader 类将数据集 dataset 封装成一个数据加载器对象 dataloader。在创建 dataloader 对象时,我们指定了 batch_size=2 和 shuffle=True,表示每个批次的大小为 2,且在每个 epoch 开始时重新打乱数据集。最后,我们通过迭代 dataloader 对象来获取数据批次,并输出其内容。
阅读全文