基于resnet50对图像能见度检测的代码
时间: 2023-12-15 21:58:37 浏览: 197
以下是基于ResNet50的图像能见度检测代码,使用Keras框架实现:
```python
import cv2
import numpy as np
from keras.models import Model
from keras.layers import Dense, GlobalAveragePooling2D
from keras.applications.resnet50 import ResNet50
from keras.preprocessing.image import ImageDataGenerator
# 加载ResNet50模型,去掉顶层
base_model = ResNet50(weights="imagenet", include_top=False, input_shape=(224, 224, 3))
# 添加全局平均池化层
x = base_model.output
x = GlobalAveragePooling2D()(x)
# 添加全连接层,输出为1个值
predictions = Dense(1, activation="sigmoid")(x)
# 构建模型
model = Model(inputs=base_model.input, outputs=predictions)
# 冻结ResNet50的前100层
for layer in model.layers[:100]:
layer.trainable = False
# 编译模型
model.compile(optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"])
# 创建ImageDataGenerator实例,用于数据增强和图像预处理
datagen = ImageDataGenerator(rescale=1./255, shear_range=0.2, zoom_range=0.2, horizontal_flip=True)
# 加载训练集和验证集
train_set = datagen.flow_from_directory("train/", target_size=(224, 224), batch_size=32, class_mode="binary")
val_set = datagen.flow_from_directory("val/", target_size=(224, 224), batch_size=32, class_mode="binary")
# 训练模型
model.fit_generator(train_set, steps_per_epoch=len(train_set), epochs=10, validation_data=val_set, validation_steps=len(val_set))
# 加载测试集
test_set = datagen.flow_from_directory("test/", target_size=(224, 224), batch_size=32, class_mode="binary")
# 评估模型性能
scores = model.evaluate_generator(test_set, steps=len(test_set))
print("Accuracy: %.2f%%" % (scores[1]*100))
# 对单张图像进行预测
img = cv2.imread("test/image1.jpg")
img = cv2.resize(img, (224, 224))
img = np.expand_dims(img, axis=0)
img = img/255.0
pred = model.predict(img)
if pred > 0.5:
print("图像能见度不好")
else:
print("图像能见度良好")
```
需要注意的是,这里的训练集、验证集和测试集需要按照Keras的ImageDataGenerator要求的格式组织,具体可以参考Keras文档。同时,也可以根据实际情况对数据增强参数进行调整。
阅读全文