用pytorch划分训练集和测试集
时间: 2023-09-05 20:09:19 浏览: 163
可以使用PyTorch的`torch.utils.data.random_split`函数来划分训练集和测试集。下面是一个示例代码:
```python
import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data.dataset import random_split
# 假设你的数据集是一个自定义的PyTorch Dataset对象
dataset = MyDataset()
# 定义训练集和测试集的比例
train_ratio = 0.8 # 训练集占总数据集的比例
test_ratio = 1 - train_ratio # 测试集占总数据集的比例
# 计算划分的样本数量
train_size = int(train_ratio * len(dataset))
test_size = len(dataset) - train_size
# 使用random_split函数划分训练集和测试集
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])
# 创建DataLoader对象来加载数据
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
```
在这个示例中,首先创建了一个自定义的`MyDataset`对象来表示你的数据集。然后,定义了训练集和测试集所占的比例。接下来,根据比例计算训练集和测试集的样本数量。最后,使用`random_split`函数将数据集划分为训练集和测试集,并创建相应的DataLoader对象来加载数据。
请注意,这只是一个示例代码,你需要根据你的具体数据集和需求进行相应的修改。
阅读全文