yolov8如何使用非极大值抑制,这是给出的非极大值抑制函数def non_max_suppression( prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False, labels=(), max_det=300, n
时间: 2024-02-14 09:17:21 浏览: 172
非极大值抑制(Non-Maximum Suppression, NMS)是一种用于去除重叠检测框的算法。在YOLOv8中,NMS的实现可以参考以下的代码:
```python
def non_max_suppression(prediction, conf_thres=0.25, iou_thres=0.45, classes=None, agnostic=False, multi_label=False, labels=(), max_det=300, nms_kind="greedy", beta_nms=0.6):
# prediction是网络输出,包含了检测框的位置和置信度等信息
# conf_thres是置信度的阈值,低于该阈值的检测框会被忽略
# iou_thres是IoU的阈值,重叠度高于该阈值的检测框会被合并
# classes是要保留的类别,如果为None,则保留所有类别
# agnostic表示是否忽略检测框的类别
# multi_label表示是否允许一个物体被多个框检测到
# labels是给定的标签列表,只有这些标签的检测框会被保留
# max_det是最多保留的检测框数量
# nms_kind表示采用哪种NMS算法,可以是"greedy"或"soft"
# beta_nms是软NMS算法中的参数
# ...
# 对每个图像进行处理,假设prediction的shape为(batch_size, num_anchors, num_classes+5)
for i, (pred, im_labels, _) in enumerate(zip(prediction, labels, image_sizes)):
# ...
# 获取置信度大于阈值的检测框
pred = pred[pred[:, 4] > conf_thres]
# 如果没有符合要求的检测框,则跳过
if not pred.size(0):
continue
# 根据置信度从大到小排序
pred = pred[(-pred[:, 4]).argsort()]
# 如果指定了类别,则只保留该类别的检测框
if classes is not None:
pred = pred[pred[:, 5].long() == classes]
# 如果指定了标签,则只保留包含该标签的检测框
if len(im_labels):
pred = pred[np.array([all(x in p for x in im_labels) for p in pred[:, -1]])]
# 如果没有符合要求的检测框,则跳过
if not pred.size(0):
continue
# 获取检测框的坐标和置信度
pred_boxes = pred[:, :4]
pred_scores = pred[:, 4]
# 计算检测框的面积
pred_areas = (pred_boxes[:, 2] - pred_boxes[:, 0]) * (pred_boxes[:, 3] - pred_boxes[:, 1])
# 初始化被保留的检测框
keep = []
# 采用不同的NMS算法进行处理
if nms_kind == "greedy":
# 采用贪心NMS算法
while pred.size(0):
# 保留置信度最大的检测框
i = pred_scores.argmax()
keep.append(i)
# 计算当前检测框与其他检测框的IoU
ious = bbox_iou(pred_boxes[i:i+1], pred_boxes)
# 找出IoU大于阈值的检测框,删除它们
idx = (ious > iou_thres).nonzero(as_tuple=False).squeeze()
pred = np.delete(pred, idx, axis=0)
# 删除已经处理过的检测框
mask = np.array([i not in idx for i in range(len(pred))])
pred_boxes = pred_boxes[mask]
pred_scores = pred_scores[mask]
pred_areas = pred_areas[mask]
elif nms_kind == "soft":
# 采用软NMS算法
while pred.size(0):
# 保留置信度最大的检测框
i = pred_scores.argmax()
keep.append(i)
# 计算当前检测框与其他检测框的IoU
ious = bbox_iou(pred_boxes[i:i+1], pred_boxes)
# 采用指数函数进行权重调整
weights = np.exp(-(ious * ious) / beta_nms)
# 更新置信度
pred_scores *= weights
# 删除置信度小于阈值的检测框
idx = (pred_scores < conf_thres).nonzero(as_tuple=False).squeeze()
pred = np.delete(pred, idx, axis=0)
# 删除已经处理过的检测框
mask = np.array([i not in idx for i in range(len(pred))])
pred_boxes = pred_boxes[mask]
pred_scores = pred_scores[mask]
pred_areas = pred_areas[mask]
# 只保留前max_det个检测框
keep = keep[:max_det]
# 更新预测结果
prediction[i] = prediction[i][keep]
return prediction
```
该函数的输入参数包括网络输出(prediction)、置信度阈值(conf_thres)、IoU阈值(iou_thres)、要保留的类别(classes)、是否忽略检测框的类别(agnostic)、是否允许一个物体被多个框检测到(multi_label)、给定的标签列表(labels)、最多保留的检测框数量(max_det)、采用哪种NMS算法(nms_kind)和软NMS算法中的参数(beta_nms)等。在函数内部,首先根据置信度阈值和类别过滤掉一部分检测框,然后按照置信度从大到小排序。接着,分别采用贪心NMS算法和软NMS算法对每个图像的检测框进行处理,最终得到被保留的检测框。最后,将保留下来的检测框更新到预测结果中,返回更新后的预测结果。
阅读全文