使用训练数据对InceptionV3模型进行训练代码
时间: 2024-05-10 15:19:15 浏览: 85
Python-使用预训练好的InceptionV3模型对自己的数据进行分类
以下是使用Keras库对InceptionV3模型进行训练的示例代码:
```python
from keras.applications.inception_v3 import InceptionV3
from keras.preprocessing.image import ImageDataGenerator
from keras.layers import Dense, GlobalAveragePooling2D
from keras.models import Model
from keras.optimizers import SGD
# 定义数据增强
train_datagen = ImageDataGenerator(
rescale=1./255,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True)
test_datagen = ImageDataGenerator(rescale=1./255)
# 创建训练集和验证集的数据生成器
train_generator = train_datagen.flow_from_directory(
'train/',
target_size=(299, 299),
batch_size=32,
class_mode='categorical')
validation_generator = test_datagen.flow_from_directory(
'validation/',
target_size=(299, 299),
batch_size=32,
class_mode='categorical')
# 加载预训练的InceptionV3模型
base_model = InceptionV3(weights='imagenet', include_top=False)
# 添加自定义全局平均池化层和全连接层
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(1024, activation='relu')(x)
predictions = Dense(10, activation='softmax')(x)
# 构建完整的模型
model = Model(inputs=base_model.input, outputs=predictions)
# 冻结预训练模型的所有层
for layer in base_model.layers:
layer.trainable = False
# 编译模型
model.compile(optimizer=SGD(lr=0.001, momentum=0.9), loss='categorical_crossentropy', metrics=['accuracy'])
# 训练模型
model.fit_generator(
train_generator,
steps_per_epoch=2000,
epochs=50,
validation_data=validation_generator,
validation_steps=800)
# 保存模型
model.save('inceptionv3_trained.h5')
```
在这个示例中,我们使用ImageDataGenerator来对训练数据进行数据增强,提高模型的泛化能力。然后,我们加载预训练的InceptionV3模型,添加自定义全局平均池化层和全连接层,并构建完整的模型。接着,我们冻结预训练模型的所有层,并使用SGD优化器编译模型。最后,我们使用fit_generator方法来训练模型,并使用save方法保存模型。
阅读全文