pytorch 划分数据集
时间: 2023-10-18 16:09:47 浏览: 87
在 PyTorch 中,可以使用 `torch.utils.data.random_split` 函数来将数据集划分为训练集和验证集。
下面是一个例子:
```python
import torch
from torch.utils.data import Dataset, DataLoader, random_split
# 自定义数据集类
class MyDataset(Dataset):
def __init__(self, data):
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, index):
return self.data[index]
# 假设有100个样本数据
data = list(range(100))
# 创建数据集对象
dataset = MyDataset(data)
# 将数据集划分为训练集和验证集
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=True)
```
在上面的例子中,我们首先定义了一个自定义的数据集类 `MyDataset`,然后将数据集划分为训练集和验证集。最后,我们使用 `DataLoader` 类来创建数据加载器,以便在训练模型时使用。
阅读全文