使用torchvision加载自己的数据
时间: 2024-05-07 22:15:10 浏览: 6
要使用torchvision加载自己的数据,需要按照以下步骤:
1. 准备数据
将自己的数据准备好,并按照一定的目录结构组织,例如:
```
data/
train/
class1/
img1.jpg
img2.jpg
...
class2/
img1.jpg
img2.jpg
...
...
val/
class1/
img1.jpg
img2.jpg
...
class2/
img1.jpg
img2.jpg
...
...
```
其中,train文件夹包含训练集数据,val文件夹包含验证集数据,每个类别的数据应该放在一个单独的文件夹中。
2. 创建数据集对象
使用torchvision.datasets.ImageFolder创建数据集对象,并指定数据目录和数据变换(可选):
```python
import torchvision.datasets as datasets
import torchvision.transforms as transforms
train_dir = "data/train"
val_dir = "data/val"
# 数据变换(可选)
train_transform = transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
val_transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
# 创建数据集对象
train_dataset = datasets.ImageFolder(train_dir, transform=train_transform)
val_dataset = datasets.ImageFolder(val_dir, transform=val_transform)
```
3. 创建数据加载器
使用torch.utils.data.DataLoader创建数据加载器,并指定批次大小、是否打乱数据等参数:
```python
import torch.utils.data as data
batch_size = 32
# 创建数据加载器
train_loader = data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = data.DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
```
现在,就可以使用train_loader和val_loader分别加载训练集数据和验证集数据了。