迁移学习特征提取网络共享特征的代码
时间: 2024-08-28 07:02:34 浏览: 33
迁移学习的工具箱代码
4星 · 用户满意度95%
迁移学习是一种深度学习技术,其中预训练模型的特征提取部分被保留并用于新任务的模型初始化。在Python中,特别是在使用Keras或PyTorch这样的库时,可以这样做:
```python
# 使用Keras示例(假设我们有VGG16模型)
from keras.applications.vgg16 import VGG16, preprocess_input
from keras.layers import Flatten, Dense
# 加载预训练的VGG16模型(不包含顶部全连接层)
base_model = VGG16(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
# 将所有卷积层设置为不可训练
for layer in base_model.layers:
layer.trainable = False
# 添加新的自定义全连接层(例如,512个节点)
x = base_model.output
x = Flatten()(x)
x = Dense(512, activation='relu')(x)
# 创建完整的模型,包括预训练特征提取和新添加的分类层
model = Model(inputs=base_model.input, outputs=x)
# 如果需要微调,可以对顶部的一些层进行训练
# model.layers[-n:].trainable = True # 其中n是你想要调整的层数
# 编译模型(记得设定适当的损失函数、优化器和评估指标)
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
# 进行特征提取和训练
```
这里的`Flatten`层将卷积后的二维张量展平成一维向量,`Dense`层则是添加新的全连接层。在实际应用中,你需要替换输入形状和自定义的全连接结构。
阅读全文