yolov5后处理 python代码
时间: 2023-08-24 09:09:44 浏览: 249
以下是一个简单的 YOLOv5 后处理 Python 代码示例:
```python
import torch
from numpy import random
def non_max_suppression(prediction, conf_thres=0.1, iou_thres=0.6, classes=None, agnostic=False):
"""
对 YOLOv5 预测的边界框进行非极大值抑制(NMS)。
prediction: YOLOv5 模型的预测输出,包含所有检测边界框的信息。
conf_thres: 置信度阈值,低于该值的边界框将被忽略。
iou_thres: IOU 阈值,高于该值的边界框将被视为重叠,并进行 NMS 处理。
classes: 只保留指定类别的边界框,如果为 None,则保留所有类别的边界框。
agnostic: 是否对类别进行融合,即不考虑类别信息。
返回值:经过 NMS 处理后的边界框信息。
"""
# 从预测结果中提取边界框信息
box_corner = prediction[:, :, :4]
box_wh = box_corner[:, :, 2:4] - box_corner[:, :, :2]
box_area = box_wh[..., 0] * box_wh[..., 1]
box_center = (box_corner[:, :, 2:4] + box_corner[:, :, :2]) / 2
# 根据置信度进行筛选
scores = prediction[:, :, 4]
score_mask = scores > conf_thres
# 如果没有符合条件的边界框则返回空列表
if score_mask.sum() == 0:
return []
# 按照置信度排序
scores = scores[score_mask]
boxes = torch.cat((box_center[score_mask], box_wh[score_mask]), 2)
_, box_sort_idx = torch.sort(scores, descending=True)
boxes = boxes[box_sort_idx]
scores = scores[box_sort_idx]
# 初始化 NMS 结果
keep_boxes = []
# 进行 NMS 处理
while boxes.shape[0] > 0:
current_box = boxes[0]
current_score = scores[0]
keep_boxes.append(current_box)
if boxes.shape[0] == 1:
break
iou = bbox_iou(current_box.unsqueeze(0), boxes[1:])
overlap_mask = iou > iou_thres
if classes is not None and not agnostic:
class_mask = boxes[:, 4] == classes
overlap_mask = overlap_mask & class_mask.unsqueeze(1)
boxes = boxes[~overlap_mask]
scores = scores[~overlap_mask]
return torch.stack(keep_boxes)
def bbox_iou(box1, box2):
"""
计算两个边界框之间的 IOU。
box1: 第一个边界框,可以是一个张量。
box2: 第二个边界框,可以是一个张量或一个张量列表。
返回值:IOU 值。
"""
if box2.ndim == 1:
box2 = box2.unsqueeze(0)
b1_x1, b1_y1, b1_x2, b1_y2 = box1[:, 0], box1[:, 1], box1[:, 2], box1[:, 3]
b2_x1, b2_y1, b2_x2, b2_y2 = box2[:, 0], box2[:, 1], box2[:, 2], box2[:, 3]
inter_x1 = torch.max(b1_x1, b2_x1)
inter_y1 = torch.max(b1_y1, b2_y1)
inter_x2 = torch.min(b1_x2, b2_x2)
inter_y2 = torch.min(b1_y2, b2_y2)
inter_area = (inter_x2 - inter_x1).clamp(0) * (inter_y2 - inter_y1).clamp(0)
box1_area = (b1_x2 - b1_x1) * (b1_y2 - b1_y1)
box2_area = (b2_x2 - b2_x1) * (b2_y2 - b2_y1)
iou = inter_area / (box1_area + box2_area - inter_area)
return iou
def scale_coords(coords, img_shape, pad_shape):
"""
将边界框坐标从缩放后的图像坐标转换为原始图像坐标。
coords: 缩放后的边界框坐标,形状为 (n, 4),其中 n 是边界框的数量。
img_shape: 原始图像的形状,形状为 (height, width)。
pad_shape: 缩放后的图像的形状,形状为 (height, width)。
返回值:转换后的边界框坐标,形状为 (n, 4)。
"""
gain = min(pad_shape[0] / img_shape[0], pad_shape[1] / img_shape[1])
pad = (pad_shape - img_shape * gain) / 2
coords[:, [0, 2]] -= pad[1]
coords[:, [1, 3]] -= pad[0]
coords[:, :4] /= gain
return coords.round().astype(int)
def postprocess(prediction, img_size, conf_thres=0.1, iou_thres=0.6):
"""
对 YOLOv5 模型的预测结果进行后处理,包括 NMS 和转换坐标到原始图像坐标系。
prediction: YOLOv5 模型的预测输出,包含所有检测边界框的信息。
img_size: 原始图像的形状,形状为 (height, width)。
conf_thres: 置信度阈值,低于该值的边界框将被忽略。
iou_thres: IOU 阈值,高于该值的边界框将被视为重叠,并进行 NMS 处理。
返回值:经过 NMS 处理和坐标转换后的边界框信息。
"""
# 进行 NMS 处理
prediction[..., :4] = scale_coords(prediction[..., :4], img_size, img_size)
output = [non_max_suppression(pred, conf_thres, iou_thres) for pred in prediction]
# 将边界框坐标转换为原始图像坐标系
for i, pred in enumerate(output):
for obj in pred:
obj[:4] = scale_coords(obj[:4].unsqueeze(0), img_size, img_size).squeeze()
obj[5] = i
return output
```
这个代码示例包含了三个函数:
- `non_max_suppression()`:对 YOLOv5 预测的边界框进行非极大值抑制(NMS)处理;
- `bbox_iou()`:计算两个边界框之间的 IOU 值;
- `postprocess()`:对 YOLOv5 模型的预测结果进行后处理,包括 NMS 和转换坐标到原始图像坐标系。
阅读全文