pytorch 加载文件夹中各类别的图片并对数据进行分割
时间: 2023-05-21 13:02:49 浏览: 131
可以使用 PyTorch 中的 DataLoader 和 torchvision 中的 ImageFolder 来加载文件夹中的各类别图片,并对数据进行分割。具体实现可以参考以下代码:
```python
import torch
from torchvision import transforms, datasets
# 定义数据预处理方式
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
# 加载数据集
data_dir = 'path/to/dataset'
dataset = datasets.ImageFolder(data_dir, transform=transform)
# 划分训练集和测试集
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])
# 定义 DataLoader
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle=False)
```
这段代码可以将指定路径下的图片按照类别进行分类,并将其分为训练集和测试集,同时进行数据预处理和批量加载。
阅读全文