请生成一段可执行的python代码,实现:应用cnn方法对人脸图像提取特征,根据提取结果构建一个机器学习模型对人脸图像进行分类,达成识别性别的目的
时间: 2024-10-10 11:07:46 浏览: 39
首先,我们需要导入所需的库,如TensorFlow、Keras用于深度学习,以及OpenCV用于预处理图像数据。以下是一个简单的例子,展示如何使用CNN(卷积神经网络)提取人脸特征并训练一个模型进行性别识别:
```python
import tensorflow as tf
from tensorflow.keras.applications import VGG16
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Flatten
from tensorflow.keras.models import Sequential
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from keras_face_recognition import face_detection
# 预设参数
img_width, img_height = 224, 224
batch_size = 32
epochs = 50
num_classes = 2 # 二分类任务,男性和女性
# 加载预训练的VGG16模型,并去掉顶层
base_model = VGG16(weights='imagenet', include_top=False, input_shape=(img_width, img_height, 3))
# 添加新的全连接层
x = base_model.output
x = Flatten()(x)
x = Dense(256, activation='relu')(x) # 输出维度调整为256维
predictions = Dense(num_classes, activation='softmax')(x)
# 构建完整的模型
model = Model(inputs=base_model.input, outputs=predictions)
# 冻结预训练模型的权重
for layer in base_model.layers:
layer.trainable = False
# 编译模型
model.compile(optimizer=tf.keras.optimizers.Adam(), loss='sparse_categorical_crossentropy', metrics=['accuracy'])
# 数据增强
train_datagen = ImageDataGenerator(rescale=1./255,
rotation_range=20,
width_shift_range=0.2,
height_shift_range=0.2,
shear_range=0.2,
zoom_range=0.2,
horizontal_flip=True)
# 获取人脸图像路径和对应的标签
train_generator = train_datagen.flow_from_directory('path_to_train_faces',
target_size=(img_height, img_width),
batch_size=batch_size,
class_mode='binary')
# 训练模型
history = model.fit(train_generator,
steps_per_epoch=len(train_generator),
epochs=epochs,
validation_data=train_generator,
validation_steps=len(train_generator))
# 保存模型
model.save('gender_recognition_model.h5')
# 使用模型进行预测
def predict_gender(image_path):
image = face_detection.detect_face_and_align(image_path, (img_width, img_height))
predictions = model.predict(image.reshape(-1, img_width, img_height, 3))
return tf.argmax(predictions[0])
# 示例预测
image_path = 'path_to_test_image.jpg'
predicted_gender = predict_gender(image_path)
print(f"Predicted gender: {['Male', 'Female'][predicted_gender]}")
阅读全文