Keras中的ImageDataGenerator类
时间: 2024-05-08 09:21:15 浏览: 68
使用Keras 的ImageDataGenerator类实现批量数据增强
ImageDataGenerator是Keras中一个非常方便的图像数据生成器,主要用于数据增强(data augmentation)和实时数据扩充(real-time data augmentation)。它可以自动将一批原始图像转换为训练所需的随机数据,比如随机旋转、缩放、翻转等操作,从而扩大训练数据集,提高模型的泛化能力。
ImageDataGenerator类可以通过定义不同的参数来实现各种图像增强的操作,例如:
- rotation_range:旋转角度范围;
- width_shift_range、height_shift_range:图像水平、垂直方向的平移范围;
- shear_range:剪切强度;
- zoom_range:随机缩放范围;
- horizontal_flip、vertical_flip:是否随机水平、垂直翻转图像;
- fill_mode:填充模式。
使用ImageDataGenerator类时,需要先通过fit()方法计算出数据集的统计信息,然后可以通过flow()方法进行数据生成。例如:
```python
from keras.preprocessing.image import ImageDataGenerator
datagen = ImageDataGenerator(
rotation_range=20,
width_shift_range=0.2,
height_shift_range=0.2,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True)
datagen.fit(x_train)
model.fit_generator(datagen.flow(x_train, y_train, batch_size=32),
steps_per_epoch=len(x_train)/32, epochs=50)
```
这段代码中,我们先定义了一个ImageDataGenerator对象,然后定义了各种图像增强的参数。接着,我们通过fit()方法计算出数据集的统计信息,最后通过flow()方法生成扩充后的数据集,用于训练模型。
阅读全文