写一段keras迁移学习代码
时间: 2023-12-26 19:55:48 浏览: 26
下面是一个简单的Keras迁移学习代码示例,使用了ResNet50模型进行图像分类:
```
# 导入必要的模块
from keras.applications.resnet50 import ResNet50
from keras.layers import Dense, GlobalAveragePooling2D
from keras.models import Model
from keras.preprocessing.image import ImageDataGenerator
# 加载预训练模型
base_model = ResNet50(weights='imagenet', include_top=False)
# 添加新的全局平均池化层
x = base_model.output
x = GlobalAveragePooling2D()(x)
# 添加新的全连接层
x = Dense(1024, activation='relu')(x)
predictions = Dense(num_classes, activation='softmax')(x)
# 构建新的模型
model = Model(inputs=base_model.input, outputs=predictions)
# 冻结预训练模型的层
for layer in base_model.layers:
layer.trainable = False
# 编译模型
model.compile(optimizer='rmsprop', loss='categorical_crossentropy')
# 数据增强
train_datagen = ImageDataGenerator(
rescale=1./255,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True)
test_datagen = ImageDataGenerator(rescale=1./255)
train_generator = train_datagen.flow_from_directory(
train_dir,
target_size=(224, 224),
batch_size=batch_size,
class_mode='categorical')
validation_generator = test_datagen.flow_from_directory(
validation_dir,
target_size=(224, 224),
batch_size=batch_size,
class_mode='categorical')
# 训练模型
model.fit_generator(
train_generator,
steps_per_epoch=nb_train_samples // batch_size,
epochs=epochs,
validation_data=validation_generator,
validation_steps=nb_validation_samples // batch_size)
# 保存模型
model.save('my_model.h5')
```
其中,`ResNet50`是使用ImageNet数据集预训练的模型,我们在其基础上添加了新的全局平均池化层和全连接层。通过 `trainable` 参数可以控制哪些层需要被训练,这里我们冻结了预训练模型的所有层,只训练新添加的层。然后使用 `ImageDataGenerator` 对数据进行数据增强和预处理,并使用 `fit_generator` 方法进行模型训练。最后,使用 `save` 方法保存模型。