FCOS 导出 ONNX 之后的后处理代码 C++ 实现
时间: 2024-05-15 21:13:20 浏览: 226
FCOS 检测器的后处理是将网络输出的 raw prediction 转换为检测结果的过程。下面给出在 C++ 中实现 FCOS 后处理的代码:
```c++
#include <iostream>
#include <vector>
#include <algorithm>
#include <cmath>
// One detected box: two corner coordinates plus a confidence score.
struct DetectionResult {
    float x1;     // x of the top-left corner
    float y1;     // y of the top-left corner
    float x2;     // x of the bottom-right corner
    float y2;     // y of the bottom-right corner
    float score;  // confidence of the detection
};
// Intersection-over-union of two boxes.
// Widths and heights use the inclusive "+1" convention (a box whose x1 == x2
// is one unit wide). Returns 0 when the boxes do not overlap.
float iou(const DetectionResult& a, const DetectionResult& b) {
    // Corners of the overlap rectangle.
    const float left = std::max(a.x1, b.x1);
    const float top = std::max(a.y1, b.y1);
    const float right = std::min(a.x2, b.x2);
    const float bottom = std::min(a.y2, b.y2);
    if (left > right || top > bottom) {
        return 0.0;
    }

    const float overlap = (right - left + 1) * (bottom - top + 1);
    const float size_a = (a.x2 - a.x1 + 1) * (a.y2 - a.y1 + 1);
    const float size_b = (b.x2 - b.x1 + 1) * (b.y2 - b.y1 + 1);
    const float combined = size_a + size_b - overlap;
    return overlap / combined;
}
// FCOS 后处理函数
std::vector<DetectionResult> fcos_postprocess(float* cls_logits, float* ctr_logits, float* bbox_preds,
int num_classes, int num_anchors, int image_width, int image_height,
float score_threshold, float nms_threshold) {
// 计算 feature map 的大小
int feature_map_size = sqrt(num_anchors);
std::vector<DetectionResult> detections;
// 对每个 anchor 进行处理
for (int i = 0; i < num_anchors; ++i) {
// 获取分类概率、中心偏移量和宽高偏移量
float* cls_scores = cls_logits + i * num_classes;
float* ctr_scores = ctr_logits + i;
float* bbox_deltas = bbox_preds + i * 4;
// 获取 anchor 的坐标
int x = i % feature_map_size;
int y = i / feature_map_size;
float x_center = (x + 0.5) * image_width / feature_map_size;
float y_center = (y + 0.5) * image_height / feature_map_size;
float width = (float)image_width / feature_map_size;
float height = (float)image_height / feature_map_size;
// 计算 anchor 的左上角和右下角坐标
float x1 = x_center - width / 2;
float y1 = y_center - height / 2;
float x2 = x_center + width / 2;
float y2 = y_center + height / 2;
// 计算分类概率
float max_cls_score = cls_scores[0];
int max_cls_idx = 0;
for (int j = 1; j < num_classes; ++j) {
if (cls_scores[j] > max_cls_score) {
max_cls_score = cls_scores[j];
max_cls_idx = j;
}
}
// 计算置信度
float ctr_score = sigmoid(ctr_scores[0]);
float confidence = max_cls_score * ctr_score;
// 如果置信度大于阈值,进行检测
if (confidence > score_threshold) {
// 计算偏移量
float dx = bbox_deltas[0];
float dy = bbox_deltas[1];
float dw = bbox_deltas[2];
float dh = bbox_deltas[3];
// 计算预测框的左上角和右下角坐标
float pred_x1 = x_center + dx * width - 0.5 * exp(dw) * width;
float pred_y1 = y_center + dy * height - 0.5 * exp(dh) * height;
float pred_x2 = x_center + dx * width + 0.5 * exp(dw) * width;
float pred_y2 = y_center + dy * height + 0.5 * exp(dh) * height;
// 将结果加入检测结果数组
DetectionResult detection;
detection.x1 = pred_x1;
detection.y1 = pred_y1;
detection.x2 = pred_x2;
detection.y2 = pred_y2;
detection.score = confidence;
detections.push_back(detection);
}
}
// 对检测结果进行非极大值抑制
std::vector<DetectionResult> final_detections;
std::sort(detections.begin(), detections.end(), [](const DetectionResult& a, const DetectionResult& b) {
return a.score > b.score;
});
std::vector<bool> is_suppressed(detections.size(), false);
for (int i = 0; i < detections.size(); ++i) {
if (!is_suppressed[i]) {
final_detections.push_back(detections[i]);
for (int j = i + 1; j < detections.size(); ++j) {
if (!is_suppressed[j] && iou(detections[i], detections[j]) > nms_threshold) {
is_suppressed[j] = true;
}
}
}
}
return final_detections;
}
```
需要注意的是,这里的 `sigmoid` 函数是将任意实数输入映射到 (0, 1) 开区间的函数:
```c++
// Logistic sigmoid: 1 / (1 + e^(-x)).
float sigmoid(float x) {
    const auto decay = exp(-x);  // e^(-x)
    return 1.0 / (1.0 + decay);
}
```
阅读全文