数据增强pytorch代码
时间: 2023-07-28 15:09:48 浏览: 100
以下是一个基本的数据增强的 PyTorch 代码示例:
```python
import torch
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
class MyDataset(Dataset):
def __init__(self, x, y, transform=None):
self.x = x
self.y = y
self.transform = transform
def __len__(self):
return len(self.x)
def __getitem__(self, idx):
x = self.x[idx]
y = self.y[idx]
if self.transform:
x = self.transform(x)
return x, y
# 定义数据增强操作
transform = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomRotation(15),
transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
transforms.ToTensor()
])
# 加载数据集
x_train = torch.randn(100, 3, 32, 32)
y_train = torch.randint(0, 10, (100,))
train_dataset = MyDataset(x_train, y_train, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=10, shuffle=True)
# 训练模型
for epoch in range(10):
for x_batch, y_batch in train_loader:
# 在这里训练模型
pass
```
在这个示例中,我们定义了一个自定义的数据集类 `MyDataset`,其中包含输入数据 `x` 和标签数据 `y`,以及一个可选的数据增强操作 `transform`。我们使用 PyTorch 内置的一些数据增强操作,包括随机水平翻转、随机旋转、颜色抖动和将图像转换为张量。
然后,我们将数据集加载到 `DataLoader` 中,设置批量大小为 10,并在每个训练周期中对数据进行随机重排。最后,我们可以使用训练循环来训练我们的模型。
阅读全文