faster rcnn训练自己数据集的代码
时间: 2023-06-20 14:04:41 浏览: 208
Faster R-CNN 是一种基于深度学习的两阶段目标检测算法(先由 RPN 生成候选框,再由检测网络对候选框分类和回归),可以在自己的数据集上进行训练。下面是使用 Faster R-CNN 训练自定义数据集的步骤和代码示例:
1. 准备训练数据集
首先需要准备训练数据集,包括图像和标注文件。标注文件可以是 VOC 格式或 COCO 格式。
2. 安装依赖库和下载代码
需要安装 TensorFlow、Keras 和 NumPy,并下载 Faster R-CNN 的 Keras 实现代码(下面的示例以 `keras_frcnn` 包的形式导入它)。
3. 修改配置文件
修改 Faster R-CNN 的配置文件,包括训练和测试的参数、数据集路径以及模型保存路径等。
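4. 训练模型
运行训练脚本,每次迭代先训练一步 RPN 网络,再用 RPN 产生的候选框训练一步检测网络,并在总损失下降时保存权重。
5. 测试模型
加载训练时保存的权重和配置文件(config.pickle),在测试集上评估检测效果。
6. 模型优化
根据测试结果调整超参数(如学习率、锚框尺寸与比例、RoI 数量、数据增强方式等)后重新训练。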
以下是 Faster R-CNN 训练自己数据集的代码示例。这里以 TensorFlow 和 Keras 为例,代码中的数据集为 VOC 格式。
# 导入依赖库
import pickle
import random
import time

import numpy as np
import tensorflow as tf
from keras import backend as K
from keras.callbacks import TensorBoard, ModelCheckpoint
from keras.layers import Input
from keras.models import Model
from keras.optimizers import Adam
from keras.utils import generic_utils
from keras.utils import plot_model

from keras_frcnn import config
from keras_frcnn import data_generators
from keras_frcnn import losses as losses_fn
from keras_frcnn import roi_helpers
from keras_frcnn import resnet as nn
from keras_frcnn import visualize
# Training settings (edit these for your own dataset).
config_output_filename = 'config.pickle'  # file the training configuration is pickled to
network = 'resnet50'  # informational; the backbone actually used is the module imported as `nn`
num_epochs = 1000
output_weight_path = './model_frcnn.hdf5'  # trained weights are written here
input_weight_path = './resnet50_weights_tf_dim_ordering_tf_kernels.h5'  # pretrained backbone weights
tensorboard_dir = './logs'
train_path = './train.txt'  # training annotations
test_path = './test.txt'  # test annotations
num_rois = 32  # number of RoIs fed to the detector head per image
# Data-augmentation switches.
horizontal_flips = True
vertical_flips = True
rot_90 = True
# Create the keras_frcnn configuration object and copy this script's settings onto it.
# NOTE: this rebinds the name `config` from the imported module to the instance.
# NOTE(review): the attribute names below follow upstream keras_frcnn.config.Config;
# confirm them against the installed version.
config = config.Config()
config.use_horizontal_flips = horizontal_flips
config.use_vertical_flips = vertical_flips
config.rot_90 = rot_90
config.num_rois = num_rois
config.model_path = output_weight_path

# Load the annotated training and test images.
# NOTE(review): upstream keras-frcnn exposes its annotation readers as
# keras_frcnn.simple_parser.get_data / keras_frcnn.pascal_voc_parser.get_data;
# confirm that data_generators.get_data exists in the installed version.
all_imgs, classes_count, class_mapping = data_generators.get_data(train_path)
test_imgs, _, _ = data_generators.get_data(test_path)

# Register the background class 'bg' (with the next free index) if the
# annotations do not already contain it.
if 'bg' not in classes_count:
    classes_count['bg'] = 0
    class_mapping['bg'] = len(class_mapping)
config.class_mapping = class_mapping

# Persist the configuration, including the class mapping, so a later test /
# inference run can rebuild the same setup.
with open(config_output_filename, 'wb') as config_f:
    pickle.dump(config, config_f)
# Assemble the networks: a shared backbone feeding an RPN head and a RoI detector head.
# (Kept separate from the Config object: `config` is the configuration, this is an int.)
num_channels = config.num_channels
img_size = (config.im_size, config.im_size)
img_input = Input(shape=(None, None, num_channels))  # variable-size images, channels last
roi_input = Input(shape=(num_rois, 4))  # num_rois boxes with 4 coordinates each
shared_layers = nn.nn_base(img_input, trainable=True)

# RPN head; the number of anchor shapes is every scale paired with every ratio.
num_anchors = len(config.anchor_box_scales) * len(config.anchor_box_ratios)
rpn_layers = nn.rpn(shared_layers, num_anchors)

# Detector head operating on the num_rois boxes supplied through roi_input.
classifier = nn.classifier(shared_layers, roi_input, num_rois, nb_classes=len(classes_count), trainable=True)

# Keep only the first two RPN outputs so they line up with the two RPN losses
# (classification, regression) the model is compiled with.
# NOTE(review): upstream keras_frcnn's nn.rpn also returns the backbone feature map
# as a third element; the slice is a no-op if it returns exactly two outputs.
model_rpn = Model(img_input, rpn_layers[:2])
model_classifier = Model([img_input, roi_input], classifier)

# Initialise from the pretrained weights file, matching layers by name.
model_rpn.load_weights(input_weight_path, by_name=True)
model_classifier.load_weights(input_weight_path, by_name=True)
# Build the (endless) training-data generator.
# NOTE(review): upstream keras_frcnn's get_anchor_gt has the signature
# (all_img_data, class_count, C, img_length_calc_function, backend, mode='train')
# and reads the augmentation switches from the Config object rather than from
# keyword arguments; confirm against the installed keras_frcnn version.
config.use_horizontal_flips = horizontal_flips
config.use_vertical_flips = vertical_flips
config.rot_90 = rot_90
data_gen_train = data_generators.get_anchor_gt(
    all_imgs,
    classes_count,
    config,
    nn.get_img_output_length,
    K.image_dim_ordering(),
    mode='train',
)
# Compile both models. Each model gets its own Adam instance so the two models
# do not share optimizer state.
optimizer = Adam(lr=1e-5)
optimizer_classifier = Adam(lr=1e-5)
model_rpn.compile(
    optimizer=optimizer,
    loss=[losses_fn.rpn_loss_cls(num_anchors), losses_fn.rpn_loss_regr(num_anchors)],
)
# The regression loss is given len(classes_count) - 1 classes, presumably excluding
# the 'bg' class added above. The metrics key is presumably the name nn.classifier
# gives its class-score output layer; confirm against keras_frcnn.
model_classifier.compile(
    optimizer=optimizer_classifier,
    loss=[losses_fn.class_loss_cls, losses_fn.class_loss_regr(len(classes_count) - 1)],
    metrics={'dense_class_{}'.format(len(classes_count)): 'accuracy'},
)
# Training loop: each iteration runs one RPN step, turns the RPN predictions
# into RoIs, samples num_rois of them and runs one detector step.
epoch_length = 1000  # iterations per epoch
num_epochs = int(num_epochs)
iter_num = 0
# Per-iteration history; columns: rpn_cls, rpn_regr, detector_cls,
# detector_regr, detector accuracy (filled from loss_rpn / loss_class below).
losses = np.zeros((epoch_length, 5))
rpn_accuracy_rpn_monitor = []  # positive RoIs per image, summarised every epoch_length images
rpn_accuracy_for_epoch = []  # positive RoIs per image during the current epoch
start_time = time.time()
best_loss = np.inf  # np.Inf was removed in NumPy 2.0

# One model spanning both heads, so a single weights file holds the shared
# backbone, the RPN head and the detector head.
model_all = Model([img_input, roi_input], model_rpn.outputs[:2] + model_classifier.outputs)

print('Starting training')
for epoch_num in range(num_epochs):
    progbar = generic_utils.Progbar(epoch_length)
    print('Epoch {}/{}'.format(epoch_num + 1, num_epochs))

    while True:
        try:
            # Periodically report how many RPN proposals overlap ground truth.
            # The list is reset whether or not verbose output is enabled, so it cannot grow without bound.
            if len(rpn_accuracy_rpn_monitor) == epoch_length:
                mean_overlapping_bboxes = float(sum(rpn_accuracy_rpn_monitor)) / len(rpn_accuracy_rpn_monitor)
                rpn_accuracy_rpn_monitor = []
                if config.verbose:
                    print('Average number of overlapping bounding boxes from RPN = {} for {} previous iterations'.format(mean_overlapping_bboxes, epoch_length))
                    if mean_overlapping_bboxes == 0:
                        print('RPN is not producing bounding boxes that overlap the ground truth boxes. Check RPN settings or keep training.')

            X, Y, img_data = next(data_gen_train)

            # 1) One RPN training step, then predict proposals with the updated RPN.
            loss_rpn = model_rpn.train_on_batch(X, Y)
            P_rpn = model_rpn.predict_on_batch(X)

            R = roi_helpers.rpn_to_roi(P_rpn[0], P_rpn[1], config, K.image_dim_ordering(), use_regr=True, overlap_thresh=0.7, max_boxes=300)
            X2, Y1, Y2, _ = roi_helpers.calc_iou(R, img_data, config, class_mapping)

            if X2 is None:
                # No usable proposal for this image: record it and skip the detector step.
                rpn_accuracy_rpn_monitor.append(0)
                rpn_accuracy_for_epoch.append(0)
                continue

            # 2) Split RoIs by the last column of Y1 (1 -> negative, 0 -> positive).
            neg_samples = np.where(Y1[0, :, -1] == 1)[0]
            pos_samples = np.where(Y1[0, :, -1] == 0)[0]

            rpn_accuracy_rpn_monitor.append(len(pos_samples))
            rpn_accuracy_for_epoch.append(len(pos_samples))

            # 3) Sample exactly num_rois RoIs (the size of roi_input): up to half
            #    positives, the rest negatives.
            if num_rois > 1:
                if len(pos_samples) < num_rois // 2:
                    selected_pos_samples = pos_samples.tolist()
                else:
                    selected_pos_samples = np.random.choice(pos_samples, num_rois // 2, replace=False).tolist()
                num_neg = num_rois - len(selected_pos_samples)
                # Sample without replacement when enough negatives exist, otherwise allow repeats.
                selected_neg_samples = np.random.choice(neg_samples, num_neg, replace=len(neg_samples) < num_neg).tolist()
                sel_samples = selected_pos_samples + selected_neg_samples
            else:
                # num_rois == 1: pick one random negative or positive RoI. Wrap it in a
                # list so the indexing below keeps the RoI axis.
                if np.random.randint(0, 2):
                    sel_samples = [random.choice(neg_samples)]
                else:
                    sel_samples = [random.choice(pos_samples)]

            # 4) One detector training step on the sampled RoIs.
            loss_class = model_classifier.train_on_batch(
                [X, X2[:, sel_samples, :]],
                [Y1[:, sel_samples, :], Y2[:, sel_samples, :]],
            )

            losses[iter_num, 0] = loss_rpn[1]
            losses[iter_num, 1] = loss_rpn[2]
            losses[iter_num, 2] = loss_class[1]
            losses[iter_num, 3] = loss_class[2]
            losses[iter_num, 4] = loss_class[3]
            iter_num += 1

            progbar.update(iter_num, [('rpn_cls', np.mean(losses[:iter_num, 0])),
                                      ('rpn_regr', np.mean(losses[:iter_num, 1])),
                                      ('detector_cls', np.mean(losses[:iter_num, 2])),
                                      ('detector_regr', np.mean(losses[:iter_num, 3])),
                                      ('mean_overlapping_bboxes', float(sum(rpn_accuracy_for_epoch)) / len(rpn_accuracy_for_epoch))])

            if iter_num == epoch_length:
                # End of epoch: summarise, save the weights if the loss improved, move on.
                loss_rpn_cls = np.mean(losses[:, 0])
                loss_rpn_regr = np.mean(losses[:, 1])
                loss_class_cls = np.mean(losses[:, 2])
                loss_class_regr = np.mean(losses[:, 3])
                class_acc = np.mean(losses[:, 4])

                mean_overlapping_bboxes = float(sum(rpn_accuracy_for_epoch)) / len(rpn_accuracy_for_epoch)
                rpn_accuracy_for_epoch = []

                if config.verbose:
                    print('Mean number of bounding boxes from RPN overlapping ground truth boxes: {}'.format(mean_overlapping_bboxes))
                    print('Classifier accuracy for bounding boxes from RPN: {}'.format(class_acc))
                    print('Loss RPN classifier: {}'.format(loss_rpn_cls))
                    print('Loss RPN regression: {}'.format(loss_rpn_regr))
                    print('Loss Detector classifier: {}'.format(loss_class_cls))
                    print('Loss Detector regression: {}'.format(loss_class_regr))
                    print('Elapsed time: {}'.format(time.time() - start_time))

                curr_loss = loss_rpn_cls + loss_rpn_regr + loss_class_cls + loss_class_regr
                iter_num = 0
                start_time = time.time()

                if curr_loss < best_loss:
                    if config.verbose:
                        print('Total loss decreased from {} to {}, saving weights'.format(best_loss, curr_loss))
                    best_loss = curr_loss
                    model_all.save_weights(output_weight_path)

                break

        except Exception as e:
            # Best-effort training: report the failing batch (e.g. too few RoIs to
            # sample) and continue with the next one.
            print('Exception: {}'.format(e))

print('Training complete, exiting.')
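这是一个简单的 Faster R-CNN 训练自定义数据集的示例代码,可以根据自己的数据集和需求进行修改和优化。注意:代码使用的是旧版 Keras 2.x 的 API(例如 `K.image_dim_ordering()` 和 `Adam(lr=...)`),在新版 TensorFlow/Keras 中需要做相应调整。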