pytorch python 数据增强代码
时间: 2023-07-27 20:06:34 浏览: 110
下面是一个使用PyTorch和Python进行数据增强的示例代码:
```python
import torch
import torchvision.transforms as transforms
from PIL import Image
# 定义数据增强的transform
transform = transforms.Compose([
transforms.RandomCrop(224), # 随机裁剪为224x224大小
transforms.RandomHorizontalFlip(), # 随机水平翻转
transforms.RandomRotation(30), # 随机旋转±30度
transforms.ToTensor(), # 转为Tensor
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) # 归一化
])
# 加载图像
image = Image.open('image.jpg')
# 应用数据增强的transform
augmented_image = transform(image)
# 显示增强后的图像
print(augmented_image.shape)
```
在上面的示例中,我们使用了一系列的transforms来定义数据增强的操作,包括随机裁剪、随机翻转、随机旋转、转换为Tensor和归一化。然后,我们加载了一张图像,将其应用于数据增强的transform,最后打印出增强后的图像的形状。
你可以根据自己的需求调整transforms中的操作和参数,以实现不同的数据增强效果。
阅读全文