torch transform
时间: 2023-07-03 14:09:00 浏览: 43
`torch.transforms`是PyTorch中的一个模块,用于数据的预处理和增强。它包含了各种常用的数据变换方法,如缩放、裁剪、翻转、填充、归一化等。这些变换可以被串联成一个变换序列,通过调用`transforms.Compose()`来实现。同时,`transforms`模块也提供了一些自定义的变换方法,可以根据实际需求进行拓展和使用。
常用的一些数据变换方法包括:
- `transforms.Resize()`:调整图片尺寸大小;
- `transforms.CenterCrop()`:中心裁剪图片;
- `transforms.RandomCrop()`:随机裁剪图片;
- `transforms.RandomHorizontalFlip()`:随机水平翻转图片;
- `transforms.ToTensor()`:将图片转换为Tensor类型;
- `transforms.Normalize()`:数据归一化处理。
例如,可以通过以下代码创建一个变换序列,将输入的图片进行随机裁剪、随机翻转、转换为Tensor类型,并进行数据归一化处理:
```
transform = transforms.Compose([
transforms.RandomCrop(32),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])
```
然后可以将该变换序列应用于输入数据,例如:
```
dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
```