Provide PyTorch code for data augmentation
Time: 2023-09-20 08:11:43
Below are PyTorch code examples for several commonly used data augmentation methods:
1. Random crop:
```python
import torchvision.transforms as transforms
transform = transforms.Compose([
    transforms.RandomCrop(size=224),
])
```
2. Random rotation:
```python
import torchvision.transforms as transforms
transform = transforms.Compose([
    transforms.RandomRotation(degrees=30),
])
```
3. Random scaling (random resized crop):
```python
import torchvision.transforms as transforms
transform = transforms.Compose([
    transforms.RandomResizedCrop(size=224),
])
```
4. Random horizontal flip:
```python
import torchvision.transforms as transforms
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
])
```
5. Random brightness, contrast, and saturation adjustment:
```python
import torchvision.transforms as transforms
transform = transforms.Compose([
    transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
])
```
6. Random Gaussian blur (note that `GaussianBlur` smooths the image; it does not add noise):
```python
import torchvision.transforms as transforms
transform = transforms.Compose([
    # Blur with a 3 x 3 Gaussian kernel, applied to roughly half of the images.
    transforms.RandomApply([transforms.GaussianBlur(kernel_size=3)], p=0.5),
])
```
7. Random affine transformation (applied with probability 0.5):
```python
import torchvision.transforms as transforms
transform = transforms.Compose([
    transforms.RandomApply([transforms.RandomAffine(degrees=30)], p=0.5),
])
```
8. Random hue shift:
```python
import torchvision.transforms as transforms
transform = transforms.Compose([
    # Hue is shifted by a factor sampled from [-0.5, 0.5]; 0.5 is the largest allowed value.
    transforms.ColorJitter(hue=0.5),
])
```
9. Mixup:
```python
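import numpy as np
import torch

def mixup_data(x, y, alpha=1.0):
    """Return mixed inputs, both target sets, and the mixing coefficient lam."""
    # Sample lam from Beta(alpha, alpha); alpha <= 0 disables mixing.
    if alpha > 0:
        lam = float(np.random.beta(alpha, alpha))
    else:
        lam = 1.0
    batch_size = x.size(0)
    # A random permutation pairs each sample with another one from the batch.
    index = torch.randperm(batch_size, device=x.device)
    mixed_x = lam * x + (1 - lam) * x[index]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam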
```
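Because Mixup blends inputs, the loss has to be blended with the same weight `lam`. A minimal training-step sketch, assuming a classification task with `nn.CrossEntropyLoss`; the toy linear model, the data shapes, and the `mixup_criterion` helper name are hypothetical, and `mixup_data` is repeated so the block runs on its own:

```python
import numpy as np
import torch
import torch.nn as nn

def mixup_data(x, y, alpha=1.0):
    # Same logic as the function above.
    lam = float(np.random.beta(alpha, alpha)) if alpha > 0 else 1.0
    index = torch.randperm(x.size(0), device=x.device)
    return lam * x + (1 - lam) * x[index], y, y[index], lam

def mixup_criterion(criterion, pred, y_a, y_b, lam):
    # Weight the two losses with the same lam used to mix the inputs.
    return lam * criterion(pred, y_a) + (1 - lam) * criterion(pred, y_b)

# Toy setup (hypothetical): 8 samples, 10 features, 3 classes.
model = nn.Linear(10, 3)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

inputs = torch.randn(8, 10)
targets = torch.randint(0, 3, (8,))

mixed_inputs, targets_a, targets_b, lam = mixup_data(inputs, targets, alpha=0.2)
outputs = model(mixed_inputs)
loss = mixup_criterion(criterion, outputs, targets_a, targets_b, lam)

optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"lam={lam:.3f}, loss={loss.item():.4f}")
```

Small `alpha` values (such as 0.2) push `lam` toward 0 or 1, so most mixed samples stay close to one of the two originals.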
10. Cutout (expects a C x H x W tensor):
```python
import numpy as np
import torch

def cutout(image, n_holes=1, length=16):
    """Zero out n_holes square patches of side `length` in a C x H x W tensor."""
    h, w = image.shape[1], image.shape[2]
    mask = np.ones((h, w), np.float32)
    for _ in range(n_holes):
        # Pick a random center; patches near the border are clipped to the image.
        y = np.random.randint(h)
        x = np.random.randint(w)
        y1 = np.clip(y - length // 2, 0, h)
        y2 = np.clip(y + length // 2, 0, h)
        x1 = np.clip(x - length // 2, 0, w)
        x2 = np.clip(x + length // 2, 0, w)
        mask[y1:y2, x1:x2] = 0.0
    # Broadcast the H x W mask over all channels, on the same device as the image.
    mask = torch.from_numpy(mask).to(image.device)
    mask = mask.expand_as(image)
    return image * mask
```
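These examples make data augmentation straightforward in PyTorch and can help improve a model's generalization. Methods 1 to 8 transform individual images and are typically passed as a dataset's `transform`; Mixup operates on whole batches inside the training loop; Cutout expects a tensor, so it goes after `ToTensor()`.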