写一段c++的kd-tree代码
时间: 2023-05-26 09:01:50 浏览: 59
以下是一个简单的kd-tree代码示例,仅供参考:
```
#include <stdio.h>
#include <stdlib.h>
#include <math.h>
// 定义kd-tree节点
typedef struct kd_node_t
{
double point[2];
struct kd_node_t *left, *right;
} kd_node_t;
// 计算两点距离的平方
static inline double dist_sq(double x1, double y1, double x2, double y2)
{
const double dx = x1 - x2;
const double dy = y1 - y2;
return dx * dx + dy * dy;
}
// 以x轴为比较标准对点进行排序
static int cmp_x(const void *a, const void *b)
{
const double x1 = (*(kd_node_t **) a)->point[0];
const double x2 = (*(kd_node_t **) b)->point[0];
if (x1 < x2)
return -1;
else if (x1 > x2)
return 1;
else
return 0;
}
// 以y轴为比较标准对点进行排序
static int cmp_y(const void *a, const void *b)
{
const double y1 = (*(kd_node_t **) a)->point[1];
const double y2 = (*(kd_node_t **) b)->point[1];
if (y1 < y2)
return -1;
else if (y1 > y2)
return 1;
else
return 0;
}
// 建立kd-tree,返回根节点
kd_node_t *kd_build(kd_node_t **nodes, int n, int depth)
{
if (n <= 0)
return NULL;
// 判断当前深度是x轴比较还是y轴比较
const int axis = depth % 2;
if (n == 1)
return nodes[0];
// 以x或y轴为比较标准排序
if (axis == 0)
qsort(nodes, n, sizeof(kd_node_t *), cmp_x);
else
qsort(nodes, n, sizeof(kd_node_t *), cmp_y);
// 选取中位数节点作为当前节点
const int mid = n / 2;
kd_node_t *node = nodes[mid];
// 递归建立左右子树
node->left = kd_build(nodes, mid, depth + 1);
node->right = kd_build(nodes + mid + 1, n - mid - 1, depth + 1);
return node;
}
// 寻找最近邻点,返回最近邻点的距离的平方
double kd_nearest(kd_node_t *node, double x, double y, kd_node_t **best_node, double best_dist)
{
if (!node)
return best_dist;
const double dist = dist_sq(node->point[0], node->point[1], x, y);
if (dist < best_dist)
{
*best_node = node;
best_dist = dist;
}
const int axis = node->point[0] - x < 0;
if (axis == 0)
{
best_dist = kd_nearest(node->left, x, y, best_node, best_dist);
if (x - node->point[0] <= sqrt(best_dist))
best_dist = kd_nearest(node->right, x, y, best_node, best_dist);
}
else
{
best_dist = kd_nearest(node->right, x, y, best_node, best_dist);
if (node->point[0] - x <= sqrt(best_dist))
best_dist = kd_nearest(node->left, x, y, best_node, best_dist);
}
return best_dist;
}
int main()
{
const int N = 10; // 点的数量
kd_node_t *nodes[N];
srand(911);
// 随机生成N个点
for (int i = 0; i < N; ++i)
{
nodes[i] = malloc(sizeof(kd_node_t));
nodes[i]->point[0] = rand() % 20;
nodes[i]->point[1] = rand() % 20;
nodes[i]->left = NULL;
nodes[i]->right = NULL;
}
// 建立kd-tree
kd_node_t *root = kd_build(nodes, N, 0);
// 寻找最近邻点
const double x = rand() % 20;
const double y = rand() % 20;
kd_node_t *nearest_node = NULL;
double nearest_dist = kd_nearest(root, x, y, &nearest_node, INFINITY);
printf("查找点 (%.2f, %.2f) 的最近邻点为 (%.2f, %.2f),距离为 %.2f。\n", x, y, nearest_node->point[0], nearest_node->point[1], sqrt(nearest_dist));
// 释放内存
for (int i = 0; i < N; ++i)
free(nodes[i]);
return 0;
}
```