VGG16迁移学习 三分类 代码
时间: 2023-09-11 07:07:23 浏览: 130
以下是使用VGG16进行迁移学习的三分类代码示例:
```python
from keras.applications.vgg16 import VGG16
from keras.preprocessing import image
from keras.models import Model
from keras.layers import Dense, GlobalAveragePooling2D
from keras import backend as K
from keras.preprocessing.image import ImageDataGenerator
# 设置图像尺寸和类别数
img_width, img_height = 224, 224
num_classes = 3
# 创建VGG16模型
base_model = VGG16(weights='imagenet', include_top=False, input_shape=(img_width, img_height, 3))
# 添加自定义层
x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dense(1024, activation='relu')(x)
predictions = Dense(num_classes, activation='softmax')(x)
# 构建完整模型
model = Model(inputs=base_model.input, outputs=predictions)
# 锁定VGG16预训练模型的所有层
for layer in base_model.layers:
layer.trainable = False
# 编译模型
model.compile(optimizer='rmsprop', loss='categorical_crossentropy', metrics=['accuracy'])
# 设置训练和验证数据的路径
train_data_dir = 'path/to/train/data'
valid_data_dir = 'path/to/valid/data'
# 设置数据增强参数
train_datagen = ImageDataGenerator(
rescale=1./255,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True)
valid_datagen = ImageDataGenerator(rescale=1./255)
# 加载训练和验证数据
train_generator = train_datagen.flow_from_directory(
train_data_dir,
target_size=(img_width, img_height),
batch_size=32,
class_mode='categorical')
valid_generator = valid_datagen.flow_from_directory(
valid_data_dir,
target_size=(img_width, img_height),
batch_size=32,
class_mode='categorical')
# 训练模型
model.fit_generator(
train_generator,
steps_per_epoch=train_generator.samples // train_generator.batch_size,
epochs=10,
validation_data=valid_generator,
validation_steps=valid_generator.samples // valid_generator.batch_size)
# 保存模型
model.save('my_model.h5')
```
在上述代码中,我们使用VGG16作为基础模型,并添加了自定义全局平均池化层和全连接层,用于进行三分类任务。我们还锁定了VGG16预训练模型的所有层,只训练我们添加的自定义层。最后,我们使用ImageDataGenerator来加载和增强数据,并使用fit_generator函数训练模型。
阅读全文