深度学习数据增强代码
时间: 2023-08-14 14:06:23 浏览: 39
深度学习数据增强是一种常用的技术,它可以通过对原始数据进行变换、旋转、缩放等操作来生成更多的训练数据,从而提高模型的泛化能力和鲁棒性。
以下是一个基于Keras实现的简单数据增强代码示例:
```python
from keras.preprocessing.image import ImageDataGenerator
# 创建ImageDataGenerator对象
datagen = ImageDataGenerator(
rotation_range=20, # 旋转范围
width_shift_range=0.1, # 水平平移范围
height_shift_range=0.1, # 垂直平移范围
shear_range=0.2, # 错切变换范围
zoom_range=0.2, # 缩放范围
horizontal_flip=True, # 是否进行水平翻转
fill_mode='nearest' # 填充方式
)
# 加载原始数据
train_data = ...
train_labels = ...
# 对原始数据进行增强
datagen.fit(train_data)
augmented_data = datagen.flow(train_data, train_labels, batch_size=batch_size)
# 使用增强后的数据进行训练
model.fit(augmented_data, epochs=epochs, steps_per_epoch=steps_per_epoch)
```
在上面的代码中,我们使用Keras提供的`ImageDataGenerator`类创建了一个数据增强对象,并设置了旋转、平移、错切、缩放、翻转等操作的范围。然后,我们将原始数据传入`datagen.fit()`方法中进行增强,生成增强后的数据集`augmented_data`。最后,我们使用增强后的数据集进行模型训练。
当然,上面的代码只是一个简单的示例,实际应用中可能需要根据具体任务的需求进行相应的参数设置和数据增强操作。