pytorch 数据集划分
时间: 2023-11-07 18:59:39 浏览: 84
基于pytorch的分类网络
5星 · 资源好评率100%
在PyTorch中,可以使用torch.utils.data.Dataset和torch.utils.data.DataLoader来加载和划分数据集。
1. Dataset:用于定义如何加载数据集,需要实现__len__和__getitem__方法。__len__方法返回数据集的长度,__getitem__方法返回给定索引的数据样本。
2. DataLoader:用于将数据集划分成批次,并提供数据增强和并行加载等功能。可以指定批次大小、是否打乱数据、是否使用多线程等参数。
以下是一个简单的数据集划分的示例:
```python
import torch
from torch.utils.data import Dataset, DataLoader
# 定义数据集类
class MyDataset(Dataset):
def __init__(self, data):
self.data = data
def __len__(self):
return len(self.data)
def __getitem__(self, index):
return self.data[index]
# 构造数据集
data = [i for i in range(10)]
dataset = MyDataset(data)
# 划分数据集
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
# 创建DataLoader
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False)
# 使用DataLoader加载数据
for batch_idx, batch_data in enumerate(train_loader):
print(batch_data)
for batch_idx, batch_data in enumerate(test_loader):
print(batch_data)
```
上述代码中,首先定义了一个数据集类MyDataset,然后通过random_split方法将数据集划分为训练集和测试集,最后使用DataLoader加载数据。在训练集中,每次返回两个样本;在测试集中,不打乱数据,每次返回两个样本。
阅读全文