给出torchvision对T91数据集进行数据增强的代码
时间: 2024-05-03 11:18:06 浏览: 4
以下是使用torchvision对T91数据集进行数据增强的示例代码:
```python
import torch
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
train_transforms = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip(),
transforms.RandomRotation(45),
transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
transforms.RandomResizedCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])
train_dataset = ImageFolder('path/to/T91/data', transform=train_transforms)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
```
在上述代码中,我们使用了以下数据增强方法:
- `RandomHorizontalFlip`:以0.5的概率随机水平翻转图像;
- `RandomVerticalFlip`:以0.5的概率随机垂直翻转图像;
- `RandomRotation`:随机旋转图像一定角度,这里设置为45度;
- `ColorJitter`:随机改变图像的亮度、对比度、饱和度和色相;
- `RandomResizedCrop`:随机裁剪图像,并将裁剪后的图像缩放到指定大小,这里设置为224;
- `ToTensor`:将图像转换为张量;
- `Normalize`:对图像进行归一化,使其像素值在[-1, 1]之间。
上述代码中,我们使用了`ImageFolder`类加载T91数据集,并将其传递给`DataLoader`类以生成批次数据。