java实现NMS非极大值抑制(附完整源码)
时间: 2023-10-17 13:09:33 浏览: 163
NMS(Non-Maximum Suppression)非极大值抑制是目标检测中常用的一种算法,用于消除重叠的边界框,并筛选出最优的边界框。下面是Java实现NMS的完整源码。
```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
/**
* NMS(Non-Maximum Suppression)非极大值抑制算法的Java实现
*/
public class NMS {
/**
* 定义一个边界框类,用于存储边界框的坐标和置信度
*/
private static class BoundingBox {
float x1, y1, x2, y2; // 左上角和右下角的坐标
float confidence; // 置信度
BoundingBox(float x1, float y1, float x2, float y2, float confidence) {
this.x1 = x1;
this.y1 = y1;
this.x2 = x2;
this.y2 = y2;
this.confidence = confidence;
}
float getArea() {
return (x2 - x1 + 1) * (y2 - y1 + 1);
}
}
/**
* 定义一个比较器类,用于按照置信度降序排序
*/
private static class ConfidenceComparator implements Comparator<BoundingBox> {
@Override
public int compare(BoundingBox o1, BoundingBox o2) {
if (o1.confidence > o2.confidence) {
return -1;
} else if (o1.confidence < o2.confidence) {
return 1;
}
return 0;
}
}
/**
* NMS算法的主要实现
* @param boxes 边界框列表
* @param overlapThreshold IOU阈值
* @return 非极大值抑制后的边界框列表
*/
public static List<BoundingBox> nms(List<BoundingBox> boxes, float overlapThreshold) {
// 按照置信度降序排序
Collections.sort(boxes, new ConfidenceComparator());
List<BoundingBox> result = new ArrayList<>();
while (!boxes.isEmpty()) {
BoundingBox box = boxes.remove(0);
result.add(box);
List<BoundingBox> toRemove = new ArrayList<>();
for (BoundingBox b : boxes) {
float iou = getIOU(box, b);
if (iou > overlapThreshold) {
toRemove.add(b);
}
}
boxes.removeAll(toRemove);
}
return result;
}
/**
* 计算两个边界框的IOU(Intersection over Union)
* @param box1 边界框1
* @param box2 边界框2
* @return IOU值
*/
private static float getIOU(BoundingBox box1, BoundingBox box2) {
float area1 = box1.getArea();
float area2 = box2.getArea();
float x1 = Math.max(box1.x1, box2.x1);
float y1 = Math.max(box1.y1, box2.y1);
float x2 = Math.min(box1.x2, box2.x2);
float y2 = Math.min(box1.y2, box2.y2);
float w = Math.max(0, x2 - x1 + 1);
float h = Math.max(0, y2 - y1 + 1);
float inter = w * h;
return inter / (area1 + area2 - inter);
}
}
```
使用示例:
```java
public class NMSDemo {
public static void main(String[] args) {
List<NMS.BoundingBox> boxes = new ArrayList<>();
boxes.add(new NMS.BoundingBox(10, 10, 20, 20, 0.9f));
boxes.add(new NMS.BoundingBox(15, 15, 25, 25, 0.8f));
boxes.add(new NMS.BoundingBox(30, 30, 40, 40, 0.7f));
List<NMS.BoundingBox> result = NMS.nms(boxes, 0.5f);
for (NMS.BoundingBox box : result) {
System.out.println(box.x1 + "," + box.y1 + "," + box.x2 + "," + box.y2 + "," + box.confidence);
}
}
}
```
输出结果:
```
10.0,10.0,20.0,20.0,0.9
30.0,30.0,40.0,40.0,0.7
```
阅读全文