How to use the NMS function in YOLOv3
Time: 2023-10-26 10:07:41 Views: 87
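In YOLOv3, the nms function filters overlapping boxes out of the detection results: among boxes that overlap, only the one with the highest confidence is kept. The following code shows one way to implement it: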
```python
import torch


def non_max_suppression(prediction, confidence, num_classes, nms_conf=0.4):
    """
    Filter overlapping boxes out of the detections with non-maximum suppression.

    prediction: tensor of shape (batch, num_boxes, 5 + num_classes), laid out as
    (center_x, center_y, width, height, objectness, class scores...).
    Returns a list with one entry per image: None, or a tensor whose rows are
    (image_index, x1, y1, x2, y2, objectness, class_score, class_index).
    bbox_iou and unique are helper functions that must be defined separately.
    """
    # Convert (center_x, center_y, width, height) to corner coordinates (x1, y1, x2, y2)
    box_corner = prediction.new_empty(prediction.shape)
    box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
    box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
    box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
    box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
    prediction[:, :, :4] = box_corner[:, :, :4]
    output = [None for _ in range(len(prediction))]
    # Process the detections of each image in the batch
    for image_i, image_pred in enumerate(prediction):
        # Find the highest-scoring class of each box and its score
        max_conf, max_conf_idx = torch.max(image_pred[:, 5:5 + num_classes], 1)
        max_conf = max_conf.float().unsqueeze(1)
        max_conf_idx = max_conf_idx.float().unsqueeze(1)
        # Append the best class score and class index to the box coordinates and objectness
        seq = (image_pred[:, :5], max_conf, max_conf_idx)
        image_pred = torch.cat(seq, 1)
        # Drop boxes whose objectness is not above the threshold (this also drops zeroed-out rows)
        image_pred_ = image_pred[image_pred[:, 4] > confidence]
        if image_pred_.size(0) == 0:
            continue
        # Collect the classes detected in this image
        img_classes = unique(image_pred_[:, -1])
        for cls in img_classes:
            # Select the detections of this class
            image_pred_class = image_pred_[image_pred_[:, -1] == cls]
            # Sort by objectness, highest first
            conf_sort_index = torch.sort(image_pred_class[:, 4], descending=True)[1]
            image_pred_class = image_pred_class[conf_sort_index]
            i = 0
            while i + 1 < image_pred_class.size(0):
                # IoU between the current box and every lower-ranked box
                # (bbox_iou is expected to return a 1-D tensor, one value per box)
                ious = bbox_iou(image_pred_class[i].unsqueeze(0), image_pred_class[i + 1:])
                # Keep only the lower-ranked boxes whose IoU is below the threshold
                survivors = image_pred_class[i + 1:][ious < nms_conf]
                image_pred_class = torch.cat((image_pred_class[:i + 1], survivors))
                i += 1
            # Prepend the image index to each remaining box
            batch_index = image_pred_class.new_full((image_pred_class.size(0), 1), image_i)
            seq = batch_index, image_pred_class
            if output[image_i] is None:
                output[image_i] = torch.cat(seq, 1)
            else:
                output[image_i] = torch.cat((output[image_i], torch.cat(seq, 1)))
    return output
```
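Here, the bbox_iou function computes the IoU between boxes, and the unique function returns all distinct values in a tensor. To use it, pass the predictions and the confidence threshold to non_max_suppression, as shown below: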
```python
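# model, x, confidence_threshold and num_classes are assumed to be defined by the caller
# Run the model, then zero out boxes whose objectness is not above the threshold
prediction = model(x)
conf_mask = (prediction[:, :, 4] > confidence_threshold).float().unsqueeze(2)
prediction = prediction * conf_mask
# Filter out overlapping detection boxes
output = non_max_suppression(prediction, confidence_threshold, num_classes)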
```
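The answer relies on bbox_iou and unique without showing them. The block below is a minimal sketch of what they might look like, not the original implementation. It assumes boxes are in corner format (x1, y1, x2, y2) with continuous coordinates (some implementations add 1 to widths and heights for pixel coordinates), that bbox_iou returns a 1-D tensor with one IoU per row of its second argument, and that unique can simply wrap torch.unique.

```python
import torch


def bbox_iou(box1, box2):
    """IoU between box1 (shape (1, >=4)) and each row of box2 (shape (N, >=4)).

    Assumption: boxes are (x1, y1, x2, y2, ...) in corner format.
    Returns a 1-D tensor of length N.
    """
    b1_x1, b1_y1, b1_x2, b1_y2 = box1[:, 0], box1[:, 1], box1[:, 2], box1[:, 3]
    b2_x1, b2_y1, b2_x2, b2_y2 = box2[:, 0], box2[:, 1], box2[:, 2], box2[:, 3]
    # Corners of the intersection rectangle (broadcasts box1 against every row of box2)
    inter_x1 = torch.max(b1_x1, b2_x1)
    inter_y1 = torch.max(b1_y1, b2_y1)
    inter_x2 = torch.min(b1_x2, b2_x2)
    inter_y2 = torch.min(b1_y2, b2_y2)
    # Clamp at zero so that disjoint boxes get an empty intersection
    inter_area = (inter_x2 - inter_x1).clamp(min=0) * (inter_y2 - inter_y1).clamp(min=0)
    b1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1)
    b2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)
    return inter_area / (b1_area + b2_area - inter_area)


def unique(tensor):
    """Return the distinct values of a 1-D tensor (thin wrapper around torch.unique)."""
    return torch.unique(tensor)


# Toy check: one reference box against an identical, a half-overlapping, and a disjoint box
reference = torch.tensor([[0.0, 0.0, 2.0, 2.0]])
others = torch.tensor([[0.0, 0.0, 2.0, 2.0],
                       [1.0, 0.0, 3.0, 2.0],
                       [5.0, 5.0, 6.0, 6.0]])
ious = bbox_iou(reference, others)
print(ious)  # IoUs of 1, 1/3 and 0 respectively
```

With an NMS threshold of 0.4, the identical box (IoU 1) would be suppressed, while the half-overlapping box (IoU 1/3) and the disjoint box (IoU 0) would be kept.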