如何用pytorch将数据集中随机分为训练集和测试集
时间: 2024-04-10 16:28:24 浏览: 162
在使用PyTorch进行数据集的随机分割时,可以借助`torch.utils.data.random_split`函数来实现。这个函数可以将数据集随机分割为训练集和测试集,可以根据指定的比例或长度来划分数据集。
下面是一个示例代码,展示了如何使用PyTorch将数据集随机分割为训练集和测试集:
```python
import torch
from torch.utils.data import DataLoader, SubsetRandomSampler
from torchvision import datasets, transforms
# 设置随机种子
torch.manual_seed(42)
# 定义数据变换
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
# 加载数据集
dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# 定义划分比例
train_ratio = 0.8 # 训练集占总数据集的比例
# 计算划分的长度
train_size = int(train_ratio * len(dataset))
test_size = len(dataset) - train_size
# 划分数据集
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
# 创建数据加载器
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=True)
# 打印划分后的数据集大小
print(f"训练集大小: {len(train_dataset)}")
print(f"测试集大小: {len(test_dataset)}")
```
在上述代码中,首先定义了数据变换,然后使用`datasets.MNIST`加载MNIST数据集。接着,通过`torch.utils.data.random_split`函数将数据集随机划分为训练集和测试集。最后,使用`DataLoader`创建数据加载器,并打印出划分后的数据集大小。
这样,你就可以使用PyTorch将数据集随机分割为训练集和测试集了。
阅读全文