NumPy完成KD树的搜索
时间: 2024-04-29 18:22:58 浏览: 5
以下是使用 NumPy 实现的 KD 树最近邻搜索的示例代码:
```python
import numpy as np
class KDTree:
    """k-d tree over the rows of a 2-D array with nearest-neighbour lookup.

    Each node is a plain dict with the keys ``'data'`` (the point stored at
    that node), ``'left'`` and ``'right'`` (child nodes, or ``None``).

    Attributes:
        k: Number of coordinates per point (``data.shape[1]``).
        root: Root node of the tree, or ``None`` when built from zero points.
    """

    def __init__(self, data):
        """Build the tree from ``data``.

        Args:
            data: Array-like of shape ``(n_points, k)``; one point per row.

        Raises:
            ValueError: If ``data`` is not two-dimensional.
        """
        data = np.asarray(data)
        if data.ndim != 2:
            raise ValueError("data must be a 2-D array of shape (n_points, k)")
        self.k = data.shape[1]
        self.root = self.build(data, depth=0)

    def build(self, data, depth):
        """Recursively build the subtree for ``data``.

        The splitting axis cycles through the coordinates with ``depth``; the
        median row along that axis becomes the node, rows sorted before it go
        to the left subtree and rows sorted after it go to the right subtree.

        Args:
            data: Array of shape ``(n, k)`` holding the points of this subtree.
            depth: Depth of the node being built (the root is at depth 0).

        Returns:
            The node dict for this subtree, or ``None`` if ``data`` has no rows.
        """
        n = data.shape[0]
        if n == 0:
            return None
        axis = depth % self.k
        # A stable sort keeps the construction deterministic when several
        # points share the same coordinate on the splitting axis.
        order = np.argsort(data[:, axis], kind="stable")
        data = data[order]
        median = n // 2
        return {
            'data': data[median],
            'left': self.build(data[:median], depth + 1),
            'right': self.build(data[median + 1:], depth + 1),
        }

    def search(self, x):
        """Return the stored point closest to ``x`` in Euclidean distance.

        When several points are equally close, the first one met during the
        traversal is returned.

        Args:
            x: Array-like query point with ``k`` coordinates.

        Returns:
            The row of the original data that is nearest to ``x``.

        Raises:
            ValueError: If the tree is empty, or if no point has a distance
                that compares below infinity (e.g. coordinates are NaN).
        """
        x = np.asarray(x)
        best, _ = self.nearest(self.root, x, 0, None, np.inf)
        if best is None:
            raise ValueError(
                "no nearest neighbour found: the tree is empty or all "
                "distances are non-finite"
            )
        return best['data']

    def nearest(self, node, x, depth, best=None, best_dist=np.inf):
        """Search the subtree rooted at ``node`` for the point nearest to ``x``.

        Rebinding ``best``/``best_dist`` inside this method is invisible to
        the caller, so the updated pair is returned and must be used by the
        caller in place of the values it passed in.

        Args:
            node: Root of the subtree to search, or ``None``.
            x: Query point (array with ``k`` coordinates).
            depth: Depth of ``node``; selects the splitting axis.
            best: Best node found so far, or ``None``.
            best_dist: Distance from ``x`` to ``best`` (``inf`` if none yet).

        Returns:
            A ``(best, best_dist)`` tuple updated with anything closer found
            in this subtree.
        """
        if node is None:
            return best, best_dist

        d = np.linalg.norm(node['data'] - x)
        if d < best_dist:
            best, best_dist = node, d

        axis = depth % self.k
        # Descend first into the side of the splitting plane that contains x.
        if x[axis] < node['data'][axis]:
            near, far = node['left'], node['right']
        else:
            near, far = node['right'], node['left']
        best, best_dist = self.nearest(near, x, depth + 1, best, best_dist)

        # The far side can only hold a closer point if the splitting plane
        # itself is closer to x than the best match found so far.
        if np.abs(node['data'][axis] - x[axis]) < best_dist:
            best, best_dist = self.nearest(far, x, depth + 1, best, best_dist)
        return best, best_dist
```
可以使用以下代码进行测试:
```python
# Demo: build a tree from six 2-D points and query the point nearest to x.
data = np.array([[2,3],[5,4],[9,6],[4,7],[8,1],[7,2]])
tree = KDTree(data)
x = np.array([3,5])
# NOTE: [2 3], [5 4] and [4 7] are all exactly sqrt(5) away from x, so each of
# them is a valid nearest neighbour; which one is printed depends on the
# tree's traversal order and tie-breaking.
print(tree.search(x)) # one of the tied points, e.g. [2 3]
```