FastRCNN 实现对 VOCdevkit 文件夹内好坏香蕉目标识别+bounding box运行代码
时间: 2024-03-17 15:47:55 浏览: 61
tensorflow yolo.zip
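以下是对 VOCdevkit 数据集中的好/坏香蕉进行识别并绘制 bounding box 的 Python 代码示例。需要说明的是，这段代码并不是真正的 Fast R-CNN：模型只是一个简单的 CNN 图像分类器，只预测类别，不预测边框。图中绘制的 bounding box 来自 Annotations 中的标注文件，而不是模型的预测结果：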
```python
import cv2
import numpy as np
import os
import xml.etree.ElementTree as ET
import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras import models
from tensorflow.keras import optimizers
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from sklearn.model_selection import train_test_split
# Root of the VOC-style dataset. The loaders below look for one sub-folder
# per class under both JPEGImages/ and Annotations/.
data_path = "data/VOCdevkit/"
# Class names; a name's position in this list is its one-hot label index.
classes = ["good_banana", "bad_banana"]
# Build the model.
def create_model():
    """Return an uncompiled Keras CNN classifier for 224x224x3 inputs.

    Three Conv2D(3x3, ReLU) + MaxPooling2D(2x2) stages with 32, 64 and 128
    filters are followed by Flatten, a 512-unit ReLU Dense layer and a
    softmax Dense layer with one unit per entry in ``classes``.
    """
    stack = [
        layers.Conv2D(32, (3, 3), activation='relu', input_shape=(224, 224, 3)),
        layers.MaxPooling2D((2, 2)),
    ]
    # The remaining conv stages differ only in their filter count.
    for n_filters in (64, 128):
        stack.append(layers.Conv2D(n_filters, (3, 3), activation='relu'))
        stack.append(layers.MaxPooling2D((2, 2)))
    stack.append(layers.Flatten())
    stack.append(layers.Dense(512, activation='relu'))
    stack.append(layers.Dense(len(classes), activation='softmax'))
    return models.Sequential(stack)
# Load the image dataset.
def load_dataset():
    """Load every image under ``JPEGImages/<class>/`` with a one-hot label.

    Each image is read with OpenCV (3-channel BGR), resized to 224x224 and
    scaled to [0, 1]. Files are visited in sorted order so that the sample
    order is reproducible (``os.listdir`` order is arbitrary). Files OpenCV
    cannot decode (e.g. ``.DS_Store``, ``Thumbs.db``) are skipped.

    Returns:
        ``(images, labels)``: ``images`` is a float32 array of shape
        ``(N, 224, 224, 3)``, and ``labels`` is a float32 array of shape
        ``(N, len(classes))`` with a single 1.0 per row. Images are grouped
        by class, in the order of ``classes``.
    """
    images = []
    labels = []
    for class_idx, cls in enumerate(classes):
        cls_path = os.path.join(data_path, 'JPEGImages', cls)
        for img_name in sorted(os.listdir(cls_path)):
            img = cv2.imread(os.path.join(cls_path, img_name))
            if img is None:
                # imread signals failure by returning None rather than raising;
                # passing None on to cv2.resize would crash.
                continue
            img = cv2.resize(img, (224, 224))
            # float32 halves memory versus float64 and is what Keras computes in.
            images.append(img.astype(np.float32) / 255.0)
            label = np.zeros(len(classes), dtype=np.float32)
            label[class_idx] = 1.0
            labels.append(label)
    return np.array(images), np.array(labels)
# Load the bounding-box annotations.
def load_bbox():
    """Collect annotated boxes from ``Annotations/<class>/*.xml``.

    Returns:
        A dict mapping each class folder name to a list of
        ``[xmin, ymin, xmax, ymax]`` integer boxes, one per ``<object>``.
        Entries are in sorted file order, and in document order within a file.
        Boxes are grouped by the folder they were found in; the ``<name>``
        tag inside each object is not consulted. A class whose folder yields
        no objects is absent from the dict.
    """
    bbox = {}
    for cls in classes:
        cls_path = os.path.join(data_path, 'Annotations', cls)
        # Sorted for a reproducible order. Non-XML files (e.g. .DS_Store)
        # would make ET.parse raise, so only .xml files are read.
        for xml_name in sorted(os.listdir(cls_path)):
            if not xml_name.lower().endswith('.xml'):
                continue
            root = ET.parse(os.path.join(cls_path, xml_name)).getroot()
            for obj in root.findall('object'):
                bndbox = obj.find('bndbox')
                # Some annotation tools write coordinates such as "48.0", and
                # int("48.0") raises ValueError, so parse through float first.
                box = [int(float(bndbox.find(tag).text))
                       for tag in ('xmin', 'ymin', 'xmax', 'ymax')]
                bbox.setdefault(cls, []).append(box)
    return bbox
# Train the model.
def train_model():
    """Train the CNN classifier on the banana images and save it.

    The data is split 80/20 (``random_state=42``) into training and
    validation sets, and the training split is augmented on the fly. The best
    epoch by validation accuracy is checkpointed to ``model.h5``. Training
    stops after 5 epochs without improvement, at most 50 epochs. The model
    is then written to ``model_final.h5``, which ``test_model`` loads.

    Bounding boxes are not used here: this model only classifies.
    """
    images, labels = load_dataset()
    x_train, x_test, y_train, y_test = train_test_split(
        images, labels, test_size=0.2, random_state=42)
    datagen = ImageDataGenerator(
        rotation_range=20,
        width_shift_range=0.2,
        height_shift_range=0.2,
        horizontal_flip=True,
        zoom_range=0.2,
    )
    model = create_model()
    # `learning_rate` replaces the deprecated `lr` alias, which newer Keras
    # releases no longer accept.
    model.compile(loss='categorical_crossentropy',
                  optimizer=optimizers.Adam(learning_rate=1e-4),
                  metrics=['acc'])
    filepath = 'model.h5'
    checkpoint = ModelCheckpoint(filepath, monitor='val_acc', verbose=1,
                                 save_best_only=True, mode='max')
    # restore_best_weights puts the best-epoch weights back when early stopping
    # triggers, so model_final.h5 is not simply the last (possibly worse) epoch.
    early_stop = EarlyStopping(monitor='val_acc', patience=5, mode='max',
                               restore_best_weights=True)
    batch_size = 32
    # steps_per_epoch must be an integer; round up so the last partial batch
    # is still seen every epoch.
    steps = (len(x_train) + batch_size - 1) // batch_size
    # Model.fit accepts generators directly; fit_generator is deprecated and
    # has been removed from recent TensorFlow releases.
    model.fit(datagen.flow(x_train, y_train, batch_size=batch_size),
              steps_per_epoch=steps,
              epochs=50,
              validation_data=(x_test, y_test),
              callbacks=[checkpoint, early_stop])
    model.save('model_final.h5')
# Visualise predictions.
def test_model():
    """Show each image with its predicted class name and its annotated box.

    Loads ``model_final.h5`` and predicts every image returned by
    ``load_dataset()`` in one batched call. When the top softmax score
    exceeds 0.5, it draws the ground-truth box for the image's true class,
    labelled with the *predicted* class name. Press ``q`` to stop; any other
    key advances to the next image.

    Notes:
        * The network only classifies; the boxes come from the annotations
          returned by ``load_bbox()``, not from the model.
        * Every image is shown, including those used for training, so this
          is a visual sanity check rather than a held-out evaluation.
    """
    model = models.load_model('model_final.h5')
    images, labels = load_dataset()
    bbox = load_bbox()
    if len(images) == 0:
        return
    # One batched predict call instead of one call per image.
    preds = model.predict(images, verbose=0)
    # load_dataset returns images grouped by class, so the global image index
    # is not a valid index into bbox[cls]; count positions within each class.
    # NOTE(review): this still assumes one <object> per XML file and that the
    # image and annotation folders enumerate in the same order - confirm.
    seen_per_class = {}
    for img, label, pred in zip(images, labels, preds):
        cls = classes[int(np.argmax(label))]
        box_idx = seen_per_class.get(cls, 0)
        seen_per_class[cls] = box_idx + 1
        canvas = img.copy()  # draw on a copy so the loaded dataset is untouched
        boxes = bbox.get(cls, [])
        if np.max(pred) > 0.5 and box_idx < len(boxes):
            idx = int(np.argmax(pred))
            x1, y1, x2, y2 = boxes[box_idx]
            # NOTE(review): boxes are in the original image's pixel coordinates,
            # but load_dataset resized the image to 224x224. They need scaling by
            # the original width/height (available in the XML <size>) - TODO.
            cv2.rectangle(canvas, (x1, y1), (x2, y2), (0, 255, 0), 2)
            # Keep the text on-screen when the box touches the top edge.
            cv2.putText(canvas, classes[idx], (x1, max(y1 - 10, 10)),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 255, 0), 2)
        cv2.imshow('image', canvas)
        if (cv2.waitKey(0) & 0xff) == ord('q'):
            break
    cv2.destroyAllWindows()
if __name__ == '__main__':
    # Train first: test_model loads the model_final.h5 that train_model writes.
    for stage in (train_model, test_model):
        stage()
```
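注意：在使用该代码之前，需要确保已经安装了必要的 Python 库，例如 OpenCV（opencv-python）、NumPy、TensorFlow、scikit-learn（提供 train_test_split）、SciPy（ImageDataGenerator 的旋转/平移/缩放增强需要）以及 Pillow 等。另外，代码假定图片和标注按类别分目录存放，即 `VOCdevkit/JPEGImages/<类别名>/` 和 `VOCdevkit/Annotations/<类别名>/`，这与标准 VOC 数据集的扁平目录结构不同。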
阅读全文