使用ImageFolder从数据集中读取一定比例的训练集和测试集数据
时间: 2023-07-09 09:04:36 浏览: 161
可以使用PyTorch中的`random_split`函数来实现从数据集中读取一定比例的训练集和测试集数据。
下面是一个示例代码:
```python
import torch
from torchvision.datasets import ImageFolder
from torchvision.transforms import transforms
# 定义数据预处理的转换方式
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# 读取数据集
dataset = ImageFolder('path/to/dataset', transform=transform)
# 定义训练集和测试集的比例
train_ratio = 0.8
test_ratio = 0.2
# 计算训练集和测试集的数据数量
num_data = len(dataset)
num_train_data = int(num_data * train_ratio)
num_test_data = num_data - num_train_data
# 使用random_split函数将数据集划分为训练集和测试集
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [num_train_data, num_test_data])
# 打印训练集和测试集的数据数量
print(f"Number of training data: {len(train_dataset)}")
print(f"Number of test data: {len(test_dataset)}")
```
在上面的示例代码中,首先定义了数据预处理的转换方式`transform`,然后使用`ImageFolder`类读取数据集。接下来,通过定义训练集和测试集的比例,计算出训练集和测试集的数据数量。最后,使用`random_split`函数将数据集划分为训练集和测试集,并打印出训练集和测试集的数据数量。
阅读全文