pytorch的数据集制作
时间: 2023-09-28 19:07:28 浏览: 31
PyTorch的数据集制作可以通过继承`torch.utils.data.Dataset`类来实现。以下是一个简单的数据集制作示例:
```python
import torch
from torch.utils.data import Dataset
class CustomDataset(Dataset):
def __init__(self, data):
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, index):
x = self.data[index][0]
y = self.data[index][1]
return torch.tensor(x, dtype=torch.float), torch.tensor(y, dtype=torch.float)
```
在上面的代码中,我们定义了一个名为`CustomDataset`的类,它继承了`torch.utils.data.Dataset`类。在`__init__`中,我们将数据传递给数据集,`__len__`方法返回数据集的长度,`__getitem__`方法返回给定索引处的数据。
对于每个数据样本,我们将输入和输出转换为PyTorch张量并返回它们。这个数据集可以用来训练和测试模型。
使用示例:
```python
data = [(1, 2), (3, 4), (5, 6), (7, 8)]
dataset = CustomDataset(data)
print(len(dataset)) # 输出:4
x, y = dataset[0]
print(x, y) # 输出:tensor(1.) tensor(2.)
```