class RegressionDataset(Dataset): def __init__(self, x, y): super().__init__() self.features = x self.targets = y def __getitem__(self, index): x = paddle.to_tensor(self.features[index], dtype='float32') y = paddle.to_tensor(self.targets[index], dtype='float32') return x, y def __len__(self): return len(self.features)
时间: 2023-11-12 07:08:25 浏览: 95
pytorch_自定义Dataset类.docx
这是一个继承自`paddle.io.Dataset`的数据集类`RegressionDataset`,用于存储回归任务的特征和目标。其中`__init__`函数初始化数据集的特征和目标,`__getitem__`函数返回指定索引的特征和目标,并将它们转换为`paddle.Tensor`类型,`__len__`函数返回数据集的长度(即数据样本的数量)。这样定义数据集类可以方便我们在训练模型时使用`paddle.io.DataLoader`进行数据的批量读取和训练。
阅读全文