YOLOv5-7.0 code: metrics
Date: 2024-12-26 21:26:54
### YOLOv5 7.0 Metrics Code Implementation
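In YOLOv5 7.0, `utils/metrics.py` provides the functions that compute evaluation metrics such as precision, recall and mAP. These metrics are essential for understanding model performance. The snippets below are simplified sketches of the key steps: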
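#### Computing Precision and Recall
Before precision and recall can be computed, each prediction has to be marked as correct or incorrect by matching the predicted boxes against the ground-truth labels.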
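```python
import torch
from torchvision.ops import box_iou  # YOLOv5 also defines its own box_iou in utils/metrics.py


def process_batch(detections, labels, iou_thres=0.5):
    """
    Return a vector marking correct predictions. Both sets of boxes are in (x1, y1, x2, y2) format.
    Simplified single-threshold version; YOLOv5 7.0's val.py checks 10 IoU thresholds (0.5:0.95).
    Arguments:
        detections (Tensor[N, 6]): model output, one row per detection: x1, y1, x2, y2, confidence, class.
        labels (Tensor[M, 5]): ground truth, one row per object: class, x1, y1, x2, y2.
        iou_thres (float): minimum IoU for a detection to count as a match.
    Returns:
        Tensor[N] (bool): True where a prediction is correct.
    """
    matches = torch.zeros(len(detections), dtype=torch.bool, device=detections.device)
    if len(detections) == 0 or len(labels) == 0:
        return matches  # nothing to match
    iou = box_iou(labels[:, 1:], detections[:, :4])  # IoU for every (label, detection) pair, shape [M, N]
    correct_class = labels[:, 0:1] == detections[:, 5]  # class agreement, shape [M, N]
    iou = iou * correct_class.float()  # zero out pairs whose classes differ
    best_det = torch.argmax(iou, dim=1)  # best-overlapping same-class detection for each label
    for label_idx, det_idx in enumerate(best_det.tolist()):
        # Greedy matching: each detection may be claimed by at most one label
        if iou[label_idx, det_idx] >= iou_thres and not matches[det_idx]:
            matches[det_idx] = True
    return matches
```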
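This function takes two arguments: the detections produced by the model, and the ground-truth boxes from the test set together with their class labels. For each label it picks the same-class detection with the highest intersection over union (IoU), and marks that detection as correct if the IoU reaches 0.5 and the detection has not already been matched[^1].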
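To make the IoU criterion concrete, here is a minimal pure-Python sketch of IoU for a single pair of boxes in (x1, y1, x2, y2) format. The function name `iou_xyxy` is hypothetical; `box_iou` computes the same quantity for every label/detection pair at once.

```python
def iou_xyxy(a, b):
    """IoU of two boxes given as (x1, y1, x2, y2)."""
    # Width and height of the intersection rectangle (0 if the boxes do not overlap)
    inter_w = max(0.0, min(a[2], b[2]) - max(a[0], b[0]))
    inter_h = max(0.0, min(a[3], b[3]) - max(a[1], b[1]))
    inter = inter_w * inter_h
    area_a = (a[2] - a[0]) * (a[3] - a[1])
    area_b = (b[2] - b[0]) * (b[3] - b[1])
    union = area_a + area_b - inter
    return inter / union if union > 0 else 0.0


# Two 10x10 boxes shifted 5 pixels apart overlap in a 5x10 strip
print(iou_xyxy((0, 0, 10, 10), (5, 0, 15, 10)))  # 50 / 150, about 0.333
```

With a 0.5 threshold, this pair would not count as a match even if the classes agreed.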
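#### Plotting the PR Curve
A precision-recall (PR) curve gives a visual summary of how precision trades off against recall as the confidence threshold varies.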
```python
import matplotlib.pyplot as plt


def plot_pr_curve(px, py, ap, save_dir='pr_curve.png', names=()):
    """Plot precision (py) against recall (px) and save the figure to save_dir.

    ap: scalar average precision shown in the title.
    names: optional class names, either a dict {index: name} or a sequence.
    """
    fig, ax = plt.subplots(1, 1, figsize=(9, 6))
    ax.plot(px, py, linewidth=3, color='blue')
    ax.set_xlabel('Recall')
    ax.set_ylabel('Precision')
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.grid(True)
    title_string = f'AP = {ap:.3f}'
    if isinstance(names, (list, tuple)):
        names = dict(enumerate(names))  # normalize a sequence to {index: name}
    if names:
        title_string += ': ' + ', '.join(f'{k}={v}' for k, v in names.items())
    ax.set_title(title_string)
    fig.savefig(save_dir, dpi=250)
    plt.close(fig)
```
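This function takes the recall values `px` and the precision values `py`, plots precision against recall, shows the average precision (AP, the area under the curve) in the title, and saves the figure to `save_dir`.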
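The AP in the title condenses the whole curve into one number. A minimal sketch, assuming COCO-style 101-point interpolation: take the precision envelope (the best precision at each recall or beyond), sample it at recall 0.00, 0.01, ..., 1.00, and average. The function name `average_precision_101` is hypothetical, and this is not a copy of YOLOv5's own `compute_ap`.

```python
def average_precision_101(recall, precision):
    """COCO-style 101-point interpolated average precision.

    recall: non-decreasing recall values of the curve's operating points
    precision: precision at each of those points
    """
    # Precision envelope: replace each precision by the best precision
    # achieved at the same or any higher recall
    envelope = list(precision)
    for i in range(len(envelope) - 2, -1, -1):
        envelope[i] = max(envelope[i], envelope[i + 1])
    total = 0.0
    for k in range(101):
        r = k / 100  # recall levels 0.00, 0.01, ..., 1.00
        # Envelope value at the first point reaching recall r; 0 if the curve never gets there
        total += next((envelope[i] for i, rec in enumerate(recall) if rec >= r), 0.0)
    return total / 101


print(average_precision_101([0.5, 1.0], [1.0, 0.5]))  # 76/101, about 0.752
```

A detector that keeps precision 1.0 all the way to recall 1.0 scores exactly 1.0; in the example, AP drops because precision falls to 0.5 over the second half of the recall range.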
#### Main Evaluation Logic
Finally, the main program calls the helper functions above to run the complete validation.
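```python
from pathlib import Path

import torch
from tqdm import tqdm

# These imports assume the script runs from the root of the YOLOv5 7.0 repository
from models.experimental import attempt_load
from utils.dataloaders import create_dataloader
from utils.general import check_dataset, non_max_suppression, xywh2xyxy
from utils.torch_utils import select_device


@torch.no_grad()
def main(opt):
    device = select_device(opt.device)
    model = attempt_load(opt.weights, device=device)  # load FP32 model from the .pt weights
    stride = int(model.stride.max())
    data = check_dataset(opt.data)  # parse the dataset YAML
    dataloader = create_dataloader(data['val'], opt.imgsz, opt.batch_size, stride,
                                   single_cls=opt.single_cls, pad=0.5, rect=True)[0]
    stats = []
    for imgs, targets, paths, shapes in tqdm(dataloader, desc=f'Evaluating {Path(opt.data).stem}'):
        imgs = imgs.to(device, non_blocking=True).float() / 255.0  # uint8 0-255 -> float32 0.0-1.0
        targets = targets.to(device)  # rows: image index, class, x, y, w, h (normalized)
        _, _, height, width = imgs.shape
        targets[:, 2:] *= torch.tensor((width, height, width, height), device=device)  # to pixels
        out = model(imgs)[0]  # inference; in eval mode the model returns (predictions, train_outputs)
        pred = non_max_suppression(out, conf_thres=opt.conf_thres, iou_thres=opt.iou_thres)
        for si, det in enumerate(pred):  # one (n, 6) tensor per image
            labels = targets[targets[:, 0] == si, 1:]  # class, x, y, w, h
            labels = torch.cat((labels[:, :1], xywh2xyxy(labels[:, 1:5])), dim=1)  # class, x1, y1, x2, y2
            correct = process_batch(det, labels)  # defined above
            stats.append((correct.cpu(), det[:, 4].cpu(), det[:, 5].cpu(), labels[:, 0].cpu()))
        ...
```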
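This script defines a `main` function that loads the pretrained weights, builds the data loader, and runs inference batch by batch. It then uses the `process_batch` function described earlier to compare each image's predictions with its ground-truth annotations.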
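After the loop, the per-detection correct/incorrect flags and confidences are turned into PR curves. As a sketch of the idea for a single class, using plain Python lists (the function name `pr_points` is hypothetical, and this is not YOLOv5's `ap_per_class`): sort detections by descending confidence, then accumulate true and false positives. Each prefix of the sorted list corresponds to one confidence threshold.

```python
def pr_points(correct, conf, n_labels):
    """Precision and recall after each detection, in order of descending confidence.

    correct: per-detection True/False flags (e.g. from process_batch)
    conf: per-detection confidence scores
    n_labels: number of ground-truth objects of this class (must be > 0)
    """
    order = sorted(range(len(conf)), key=lambda i: conf[i], reverse=True)
    tp = fp = 0
    recall, precision = [], []
    for i in order:
        if correct[i]:
            tp += 1  # true positive: matched a ground-truth box
        else:
            fp += 1  # false positive: no matching ground-truth box
        recall.append(tp / n_labels)
        precision.append(tp / (tp + fp))
    return recall, precision


r, p = pr_points([True, False, True], [0.9, 0.8, 0.6], n_labels=4)
print(r)  # [0.25, 0.25, 0.5]
print(p)  # [1.0, 0.5, 0.666...]
```

The two returned lists have the shape of the `px` (recall) and `py` (precision) inputs that `plot_pr_curve` expects.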