那如果我要得到模型的FPS是要在train.py里得到,还是test.py还是detect.py?
时间: 2023-12-10 15:05:20 浏览: 215
在 YOLOv7 中,获取模型的 FPS(每秒帧数)应该在 `detect.py` 中进行,因为 `detect.py` 是用于测试模型推理速度和准确率的脚本。
在 `detect.py` 中,可以通过设置 `--source` 参数指定输入的图片、视频或摄像头视频流,通过设置 `--weights` 参数指定模型权重文件,然后通过调用 `detect` 函数进行模型推理。在推理过程中,可以使用 `time.time()` 计算推理时间,从而计算出模型的 FPS。
以下是一个简单的计算 FPS 的示例代码:
```python
import time
import cv2
from models.experimental import attempt_load
from utils.general import non_max_suppression, scale_coords
from utils.torch_utils import select_device
# 设置输入和输出路径
source = 'path/to/input' # 输入路径
output = 'path/to/output' # 输出路径
# 加载模型
weights = 'path/to/weights' # 模型权重文件
device = select_device('cpu') # 设备
model = attempt_load(weights, map_location=device) # 加载模型
stride = int(model.stride.max()) # 计算 stride
# 打开输入和输出文件
cap = cv2.VideoCapture(source) # 打开输入文件
out = cv2.VideoWriter(output, cv2.VideoWriter_fourcc(*'mp4v'), 30, (640, 480)) # 打开输出文件
# 推理循环
t0 = time.time()
while True:
# 读取一帧
ret, img = cap.read()
if not ret:
break
# 图像预处理
img = cv2.resize(img, (640, 480))
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
img = img.transpose(2, 0, 1) # HWC -> CHW
img = img / 255.0 # 归一化
img = torch.from_numpy(img).float().to(device) # 转换为 tensor
# 模型推理
t1 = time.time()
pred = model(img.unsqueeze(0))[0] # 推理
pred = non_max_suppression(pred, conf_thres=0.25, iou_thres=0.45) # 后处理
t2 = time.time()
# 绘制结果
for det in pred[0]:
if det is not None and len(det):
det[:, :4] = scale_coords(img.shape[2:], det[:, :4], img.shape[1:]).round()
for *xyxy, conf, cls in reversed(det):
label = f'{names[int(cls)]} {conf:.2f}'
plot_one_box(xyxy, img0, label=label, color=colors[int(cls)], line_thickness=3)
# 计算 FPS
fps = 1.0 / (t2 - t1)
# 显示结果
cv2.putText(img0, f'FPS: {fps:.1f}', (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (0, 0, 255), 2)
cv2.imshow('result', img0)
out.write(img0)
# 检查是否按下了 ESC 键
if cv2.waitKey(1) == 27:
break
# 释放资源
cap.release()
out.release()
cv2.destroyAllWindows()
t3 = time.time()
print(f'FPS: {1.0 / ((t3 - t0) / cap.get(cv2.CAP_PROP_FRAME_COUNT)):.1f}')
```
在这个示例代码中,我们使用了 `cv2.VideoCapture` 打开输入文件,使用了 `cv2.VideoWriter` 打开输出文件,使用了 `attempt_load` 函数加载模型,使用了 `non_max_suppression` 函数进行后处理,使用了 `scale_coords` 函数将坐标从缩放后的大小转换回原始大小,使用了 `plot_one_box` 函数绘制检测结果。在推理循环中,我们用 `time.time()` 计算每一帧的推理时间,并计算出模型的 FPS。最后,我们使用 `cv2.destroyAllWindows()` 关闭所有窗口,并计算出整个视频的平均 FPS。
阅读全文