Line-crossing counting with ultralytics
Date: 2024-12-30 22:26:55
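### Line-crossing counting with Ultralytics YOLOv5
To implement line-crossing counting, you can combine the object detection framework provided by `ultralytics` with custom logic that counts the objects crossing a given line. Below is a complete Python script that uses a pretrained YOLOv5 model to detect objects in a video stream and, using a virtual boundary line, counts how many objects cross it in each direction (into or out of the region).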
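The counting principle itself does not depend on the detector: each object needs an ID that stays stable across frames, and a crossing is counted when its centroid lies on one side of the line in one frame and on the other side in the next. The model-free sketch below illustrates this under those assumptions; `count_crossings` is a hypothetical helper, and the trajectories are hand-made test data, not detector output.

```python
def count_crossings(frames, line_y):
    """Count up/down crossings of the horizontal line y = line_y.

    frames: list of dicts, one per frame, mapping object ID -> (x, y) centroid.
    Image coordinates are assumed: y grows downwards, so "up" means y decreases.
    """
    up = down = 0
    previous = {}  # object ID -> centroid in the last frame it was seen
    for positions in frames:
        for obj_id, (x, y) in positions.items():
            if obj_id in previous:
                prev_y = previous[obj_id][1]
                # Crossed if the two centroids lie on different sides of the line.
                if (prev_y < line_y) != (y < line_y):
                    if y < prev_y:
                        up += 1
                    else:
                        down += 1
        previous.update(positions)
    return up, down


# Object 1 moves downwards through y = 300, object 2 moves upwards through it.
frames = [
    {1: (400, 250), 2: (600, 380)},
    {1: (400, 290), 2: (600, 320)},
    {1: (400, 330), 2: (600, 280)},
]
print(count_crossings(frames, line_y=300))  # (1, 1)
```

Note that this only works if the ID of an object is the same in every frame; an ID derived from the current position would change as soon as the object moves.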
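#### Install dependencies
First make sure the required packages are installed. The script calls `cv2.imshow`, so it needs the regular `opencv-python` build; the `opencv-python-headless` build has no GUI window support.
```bash
pip install ultralytics opencv-python numpy
```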
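#### Load the model and initialize parameters
Load the official lightweight weights pretrained on the COCO dataset as the base model, then define the counting line and the state used for counting: one counter per direction and a dictionary holding each object's previous position.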
```python
from ultralytics import YOLO
import cv2
import numpy as np
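# Load COCO-pretrained YOLOv5s weights. The ultralytics package publishes this model
# as 'yolov5su.pt' (the legacy name 'yolov5s.pt' is substituted automatically, with a warning).
# The device is not set here; pass device='cpu' or device=0 (first GPU) to predict()/track() if needed.
model = YOLO('yolov5su.pt')
line_start = (300, 300)  # start point (x, y) of the counting line, in pixels
line_end = (900, 300)  # end point (x, y); same y, so the line is horizontal
up_count = 0  # objects crossing the line upwards (bottom to top of the image)
down_count = 0  # objects crossing the line downwards (top to bottom)
previous_positions = {}  # object ID -> centroid (x, y) in the previous frame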
```
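The endpoints above are absolute pixel coordinates, so a line from (300, 300) to (900, 300) only fits frames that are at least 900 px wide. One option, assumed here and not part of the original script, is to derive the endpoints from the frame size; `horizontal_line` is a hypothetical helper whose result could be assigned to `line_start, line_end`.

```python
def horizontal_line(width, height, rel_y=0.5, margin=0.1):
    """Return (line_start, line_end) for a horizontal counting line.

    rel_y:  vertical position as a fraction of the frame height (0 = top edge).
    margin: fraction of the frame width left free at each end of the line.
    """
    y = round(height * rel_y)
    x0 = round(width * margin)
    x1 = round(width * (1 - margin))
    return (x0, y), (x1, y)


# For a 1280x720 frame: a line at mid-height spanning 10% to 90% of the width.
print(horizontal_line(1280, 720))  # ((128, 360), (1152, 360))
```

The frame size can be read from the first frame (`frame.shape[1]` is the width, `frame.shape[0]` the height) before the processing loop starts.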
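#### Process each frame
For every captured frame, run inference to obtain the predicted boxes, then post-process them: discard detections with low confidence (here, below 0.4) and update each object's position so that we can decide whether it has crossed the line. Comparing positions between frames only works if each object keeps the same ID from frame to frame, which is what an object tracker provides.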
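```python
def process_frame(frame):
    """Detect and track objects in one frame, update the crossing counters, draw the overlay."""
    global up_count, down_count
    # model.track() instead of model.predict(): the tracker assigns each object an ID that
    # stays the same across frames; persist=True keeps the tracker state between calls.
    # Detections below 0.4 confidence are dropped. Add classes=[0] to count only persons (COCO class 0).
    result = model.track(frame, persist=True, conf=0.4, verbose=False)[0]
    current_positions = {}
    if result.boxes.id is not None:  # id is None when no object is being tracked
        boxes_xyxy = result.boxes.xyxy.int().cpu().tolist()
        track_ids = result.boxes.id.int().cpu().tolist()
        for (x1, y1, x2, y2), object_id in zip(boxes_xyxy, track_ids):
            centroid_x = (x1 + x2) // 2
            centroid_y = (y1 + y2) // 2
            current_positions[object_id] = (centroid_x, centroid_y)
            if object_id not in previous_positions:
                continue  # first sighting: nothing to compare with yet
            _, prev_centroid_y = previous_positions[object_id]
            line_y = line_start[1]
            # Crossed: the previous and current centroids lie on different sides of the line,
            # and the object is within the horizontal extent of the line segment.
            crossed = (prev_centroid_y < line_y) != (centroid_y < line_y)
            if crossed and line_start[0] <= centroid_x <= line_end[0]:
                if centroid_y < prev_centroid_y:  # image y grows downwards
                    up_count += 1
                else:
                    down_count += 1
    draw_line_and_counts(frame)
    update_previous_positions(current_positions)


def draw_line_and_counts(image):
    """Draw the counting line and the current up/down counts on the image (in place)."""
    thickness = 2
    color = (0, 255, 0)  # BGR: green line
    font_scale = 1
    text_thickness = 2
    org_up = (int(line_start[0]), max(0, int(line_start[1]) - 20))
    org_down = (int(line_start[0]), min(image.shape[0], int(line_start[1]) + 40))
    cv2.line(image, line_start, line_end, color=color, thickness=thickness)
    cv2.putText(image, f'Up Count:{up_count}', org_up, cv2.FONT_HERSHEY_SIMPLEX,
                fontScale=font_scale, color=(255, 0, 0), thickness=text_thickness, lineType=cv2.LINE_AA)
    cv2.putText(image, f'Down Count:{down_count}', org_down, cv2.FONT_HERSHEY_SIMPLEX,
                fontScale=font_scale, color=(255, 0, 0), thickness=text_thickness, lineType=cv2.LINE_AA)


def update_previous_positions(new_pos_dict):
    """Remember the latest centroid of every object seen in this frame."""
    # The module-level dict is only mutated, never rebound, so no global declaration is needed.
    previous_positions.update(new_pos_dict)
```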
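The crossing test above compares only y coordinates, so it assumes a horizontal line. For a line at an arbitrary angle, a common alternative (not part of the original script) is the sign of the 2D cross product: points on opposite sides of the line through A and B give values of opposite sign. A minimal sketch, where `side_of_line` and `crossed_segment` are hypothetical helpers:

```python
def side_of_line(a, b, p):
    """2D cross product (b - a) x (p - a).

    Positive and negative values mean p lies on opposite sides of the infinite
    line through a and b; 0 means p lies exactly on that line.
    """
    return (b[0] - a[0]) * (p[1] - a[1]) - (b[1] - a[1]) * (p[0] - a[0])


def crossed_segment(a, b, p_prev, p_curr):
    """True if the move p_prev -> p_curr strictly crosses the segment a-b."""
    # p_prev and p_curr must lie on opposite sides of the line a-b ...
    d1 = side_of_line(a, b, p_prev)
    d2 = side_of_line(a, b, p_curr)
    # ... and a and b must lie on opposite sides of the movement p_prev-p_curr.
    d3 = side_of_line(p_prev, p_curr, a)
    d4 = side_of_line(p_prev, p_curr, b)
    return d1 * d2 < 0 and d3 * d4 < 0


line_a, line_b = (300, 300), (900, 300)  # the horizontal line from the script
print(crossed_segment(line_a, line_b, (500, 320), (500, 280)))  # True  (moving up)
print(crossed_segment(line_a, line_b, (100, 320), (100, 280)))  # False (outside the segment)
print(crossed_segment((0, 0), (100, 100), (10, 50), (50, 10)))  # True  (diagonal line)
```

The sign of `d2` tells which side the object ended on, which can replace the `centroid_y < prev_centroid_y` comparison when deciding the direction. For the left-to-right line in the script, a negative value means the point is above the line in image coordinates.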
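#### Run the live monitoring loop
Finally, read frames continuously from a camera or another media source and call the functions above on each frame until the stream ends or the user quits.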
```python
cap = cv2.VideoCapture('./test.mp4') # Replace with your video file path or camera index like 0.
while cap.isOpened():
    ret, frame = cap.read()
    if not ret:
        break  # end of the video or read error
    process_frame(frame)
    resized_img = cv2.resize(frame, (800, 600))
    cv2.imshow("Frame", resized_img)
    key = cv2.waitKey(1) & 0xFF
    if key == ord('q'):  # press q to stop early
        break
cv2.destroyAllWindows()
cap.release()
print(f'\nFinal Upward Crossing Objects Number={up_count}')
print(f'Final Downward Crossing Objects Number={down_count}\n')
```
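This script covers the basic requirement: whenever a pedestrian or another moving object crosses the horizontal virtual line on screen, the upward or downward crossing count is incremented automatically[^1].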
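One limitation of counting every side change: an object whose centroid jitters around the line can be counted several times. A possible mitigation, assumed here and not part of the original script, is to remember which IDs have already been counted. `CrossingCounter` is a hypothetical class that applies this rule to a horizontal line.

```python
class CrossingCounter:
    """Counts crossings of a horizontal line, at most once per object ID."""

    def __init__(self, line_y):
        self.line_y = line_y
        self.previous = {}    # object ID -> centroid y in the last frame it was seen
        self.counted = set()  # IDs that have already been counted once
        self.up = 0
        self.down = 0

    def update(self, positions):
        """positions: dict mapping object ID -> (x, y) centroid for one frame."""
        for obj_id, (_, y) in positions.items():
            prev_y = self.previous.get(obj_id)
            if prev_y is not None and obj_id not in self.counted:
                if (prev_y < self.line_y) != (y < self.line_y):
                    if y < prev_y:  # image y grows downwards
                        self.up += 1
                    else:
                        self.down += 1
                    self.counted.add(obj_id)
            self.previous[obj_id] = y


counter = CrossingCounter(line_y=300)
# Object 7 crosses downwards, then jitters back up and down across the line.
for y in (290, 305, 298, 303):
    counter.update({7: (500, y)})
print(counter.up, counter.down)  # 0 1
```

Without the `counted` set, the same trajectory would be counted as two downward crossings and one upward crossing. The trade-off is that an object that genuinely returns across the line is no longer counted again.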