VGG网络进行图像分类Python 代码
时间: 2024-11-25 16:16:43 浏览: 4
VGG网络是一种经典的深度卷积神经网络,由牛津大学的研究团队开发,主要用于图像识别和计算机视觉任务。以下是一个简单的使用Keras库在Python中实现VGG16预训练模型进行图像分类的例子:
```python
from keras.applications.vgg16 import VGG16, preprocess_input
from keras.preprocessing.image import ImageDataGenerator
from keras.models import Model
from keras.layers import Dense, GlobalAveragePooling2D
import numpy as np
# 加载预训练的VGG16模型,不包括顶部的全连接层
vgg = VGG16(weights='imagenet', include_top=False, input_shape=(224, 224, 3))
# 冻结所有层以便微调
for layer in vgg.layers:
layer.trainable = False
# 添加全局平均池化层和新全连接层
x = vgg.output
x = GlobalAveragePooling2D()(x)
x = Dense(1024, activation='relu')(x) # 这里假设我们要分类1000种物体
predictions = Dense(num_classes, activation='softmax')(x)
# 构建新的模型,包含原VGG16的特征提取部分和我们添加的新部分
model = Model(inputs=vgg.input, outputs=predictions)
# 数据增强,例如随机翻转、裁剪等
datagen = ImageDataGenerator(preprocessing_function=preprocess_input, rotation_range=20, width_shift_range=0.15,
height_shift_range=0.15)
# 编译模型
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])
# 训练模型
train_generator = datagen.flow_from_directory(train_dir, target_size=(224, 224), batch_size=batch_size, class_mode='categorical')
model.fit_generator(train_generator, steps_per_epoch=len(train_generator), epochs=num_epochs)
# 保存模型
model.save('my_vgg_model.h5')
阅读全文