yolo7检测框输出为中心点代码
时间: 2024-02-12 07:05:36 浏览: 144
以下是使用YOLOv7检测图像中的物体并输出检测框中心点坐标和宽高的Python代码示例:
```
import torch
import cv2
import numpy as np
from yolov7.utils.datasets import letterbox
from yolov7.models.experimental import attempt_load
from yolov7.utils.general import non_max_suppression, scale_coords, xyxy2xywh
# 加载YOLOv7模型
weights = 'yolov7s.pt'
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
model = attempt_load(weights, map_location=device)
# 加载物体类别标签
classes = []
with open("coco.names", "r") as f:
classes = [line.strip() for line in f.readlines()]
# 设置输入图像尺寸和缩放因子
input_size = 640
scale_factor = 1/255.0
# 读取图像并进行预处理
image = cv2.imread("image.jpg")
height, width = image.shape[:2]
image = letterbox(image, new_shape=input_size)[0]
image = image[:, :, ::-1].transpose(2, 0, 1).copy()
image = torch.from_numpy(image).float().div(255.0).unsqueeze(0)
# 将预处理后的图像输入到网络中进行推理
model.eval()
with torch.no_grad():
output = model(image.to(device))[0]
output = non_max_suppression(output, conf_thres=0.5, iou_thres=0.5)
# 解析输出层并筛选出置信度较高的物体框
boxes = []
confidences = []
class_ids = []
if output[0] is not None:
for detection in output[0]:
x1y1 = (detection[:2] * input_size).int()
x2y2 = (detection[2:4] * input_size).int()
box = torch.cat([x1y1, x2y2], dim=-1).float()
box = scale_coords(image.shape[2:], box, image.shape[2:]).tolist()
x, y, w, h = xyxy2xywh(torch.tensor(box))[0].tolist()
center_x = x + w/2
center_y = y + h/2
boxes.append([x, y, w, h])
confidences.append(float(detection[4]))
class_ids.append(int(detection[5]))
# 应用非最大抑制(NMS)算法筛选出重叠度较小的物体框
indices = cv2.dnn.NMSBoxes(boxes, confidences, 0.5, 0.4)
# 绘制筛选后的物体框并输出中心点坐标和宽高
for i in indices:
i = i[0]
x, y, w, h = boxes[i]
center_x = x + w/2
center_y = y + h/2
print(classes[class_ids[i]], center_x, center_y, w, h)
cv2.rectangle(image, (int(x), int(y)), (int(x + w), int(y + h)), (0, 255, 0), 2)
cv2.putText(image, classes[class_ids[i]], (int(x), int(y)-5), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
# 显示输出图像
cv2.imshow("YOLOv7 Object Detection", image)
cv2.waitKey(0)
cv2.destroyAllWindows()
```
其中,`yolov7s.pt`是YOLOv7模型的权重文件,`coco.names`是物体类别标签文件,`image.jpg`是待检测的图像。可以根据实际需求修改这些文件的路径和名称。在运行代码前需要确保已安装PyTorch库、OpenCV库和Numpy库,并将yolov7目录添加到Python路径中。
阅读全文