如何将训练集中猫和狗的照片分开,要pytorch代码
时间: 2023-11-22 18:56:20 浏览: 81
基于pytorch + CNN的猫狗图像识别源码+数据集+训练好的模型文件
5星 · 资源好评率100%
可以使用PyTorch中的`torchvision.datasets.ImageFolder`来读取图片数据集,并且可以自动按照文件夹名字将数据集分成不同类别。
下面是一个简单的代码示例:
```python
import torch
import torchvision
import torchvision.transforms as transforms
# 定义数据增强和预处理
transform = transforms.Compose(
[transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])])
# 加载数据集
train_dataset = torchvision.datasets.ImageFolder(root='/path/to/training/folder', transform=transform)
# 将数据集分成猫和狗两类
cat_dataset = torch.utils.data.Subset(train_dataset, train_dataset.targets == 0)
dog_dataset = torch.utils.data.Subset(train_dataset, train_dataset.targets == 1)
# 打印数据集大小
print(f"Cat dataset size: {len(cat_dataset)}")
print(f"Dog dataset size: {len(dog_dataset)}")
```
代码中,首先定义了数据增强和预处理的内容,例如图片的大小调整、中心裁剪、转换为张量,以及归一化。然后使用`ImageFolder`读取数据集,将数据集根据文件夹名称分成不同的类别。最后使用`Subset`将数据集进一步划分为猫和狗两类,并且打印数据集大小。
阅读全文