Pytorch创建自己的数据集的代码
时间: 2023-09-12 13:08:10 浏览: 45
创建自己的数据集代码可以使用 PyTorch 的 Dataset 和 DataLoader 类来实现。下面是一个简单的示例代码,展示了如何创建一个自定义的数据集类:
```python
import torch
from torch.utils.data import Dataset, DataLoader
class MyDataset(Dataset):
def __init__(self, data):
self.data = data
def __getitem__(self, index):
# 根据索引返回数据和标签
x = self.data[index]
y = self.data[index] + 1 # 假设标签是数据加1
return x, y
def __len__(self):
# 返回数据集的大小
return len(self.data)
# 创建自定义数据集对象
data = [1, 2, 3, 4, 5]
dataset = MyDataset(data)
# 创建数据加载器
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)
# 遍历数据加载器
for batch in dataloader:
inputs, labels = batch
print(inputs, labels)
```
在上面的代码中,定义了一个 `MyDataset` 类,继承自 `torch.utils.data.Dataset`。在 `__init__` 方法中初始化数据,`__getitem__` 方法根据索引返回样本数据和标签,`__len__` 方法返回数据集大小。然后,将自定义数据集对象传递给 `DataLoader` 类,可以指定批量大小、是否打乱数据等参数。最后,通过遍历数据加载器可以获取每个批次的输入和标签。
这只是一个简单的示例,你可以根据自己的需求修改代码以适应你的数据集结构和格式。