class SampleDataset(torch.utils.data.Dataset): def __init__(self): self.sequences = [] self.labels = [] for _ in range(1000): seq = torch.randn(10, 5) label = torch.zeros(2) if seq.sum() > 0: label[0] = 1 else: label[1] = 1 self.sequences.append(seq) self.labels.append(label) def __len__(self): return len(self.sequences) def __getitem__(self, idx): return self.sequences[idx], self.labels[idx]
时间: 2024-02-14 15:27:46 浏览: 179
python torch.utils.data.DataLoader使用方法
5星 · 资源好评率100%
这段代码定义了一个名为`SampleDataset`的自定义数据集类,用于生成样本数据。数据集包含了1000个序列样本,每个序列包含10个维度为5的随机数。同时,每个序列对应一个标签,标签是一个维度为2的张量。
`__init__`方法初始化了数据集的`sequences`和`labels`列表。通过循环生成1000个序列样本,并根据序列元素的和来确定标签。如果序列元素的和大于0,则标签的第一个维度为1,否则第二个维度为1。然后将序列和标签添加到对应的列表中。
`__len__`方法返回数据集的样本数量。
`__getitem__`方法根据索引值`idx`返回对应的序列和标签。
这个自定义数据集类可以用于构建PyTorch的数据加载器,并在训练模型时使用。
阅读全文