python 画kd树
时间: 2023-08-28 12:06:00 浏览: 36
以下是一个简单的 Python 代码示例，用于构建 KD 树、进行 k 近邻搜索，并绘制二维 KD 树的空间划分：
```python
import matplotlib.pyplot as plt
import numpy as np
class Node:
    """A single node of a KD tree.

    Fields (plain public attributes):
      data  -- the point stored at this node
      split -- index of the coordinate axis this node splits on
      left  -- subtree of points whose ``split`` coordinate is <= data[split]
      right -- subtree of points whose ``split`` coordinate is >= data[split]
    """

    def __init__(self, data=None, split=None, left=None, right=None):
        self.data = data
        self.split = split
        self.left = left
        self.right = right


class KDTree:
    """A k-dimensional tree with k-nearest-neighbour search and 2-D plotting."""

    def __init__(self, points):
        """Build the tree from a sequence of equal-length points.

        ``points`` may be any indexable sequence of points, e.g. a list of
        tuples or the rows of a 2-D NumPy array. The stored points are the
        original point objects, and ``search`` returns those same objects.
        """
        self.root = self.build(points)

    def build(self, points, depth=0):
        """Recursively build a balanced subtree and return its root ``Node``.

        The splitting axis cycles through the coordinates with ``depth``. The
        median point along that axis becomes the node, and the points before
        and after it (in sorted order) form the left and right subtrees.
        Returns ``None`` for an empty input.
        """
        n = len(points)
        if n <= 0:
            return None
        dims = len(points[0])
        axis = depth % dims
        # Re-sorting at every level keeps the code simple, at O(n log^2 n) overall.
        sorted_points = sorted(points, key=lambda point: point[axis])
        mid = n // 2
        return Node(data=sorted_points[mid],
                    split=axis,
                    left=self.build(sorted_points[:mid], depth + 1),
                    right=self.build(sorted_points[mid + 1:], depth + 1))

    def search(self, x, k=1):
        """Return the ``k`` stored points nearest to ``x`` (Euclidean distance).

        The result is ordered from nearest to farthest. It holds fewer than
        ``k`` points only when the tree holds fewer than ``k`` points, and it
        is empty when ``k <= 0`` or the tree is empty.
        """
        if k <= 0:
            return []
        query = np.asarray(x, dtype=float)
        best = []  # (distance, point) pairs, kept sorted by distance

        def visit(node):
            if node is None:
                return
            point = np.asarray(node.data, dtype=float)
            dist = float(np.sqrt(np.sum((point - query) ** 2)))
            if len(best) < k:
                best.append((dist, node.data))
            elif dist < best[-1][0]:
                best[-1] = (dist, node.data)
            # Sort on the distance only: on a tie, comparing the stored points
            # themselves would raise for NumPy arrays.
            best.sort(key=lambda item: item[0])

            axis = node.split
            gap = float(query[axis] - point[axis])
            # Descend into the half that contains the query first, so the
            # candidate list is as tight as possible before testing the far half.
            if gap < 0:
                near, far = node.left, node.right
            else:
                near, far = node.right, node.left
            visit(near)
            # Every point in the far half is at least |gap| away from the query
            # along this axis. So it can only improve the result if the candidate
            # list is not full yet, or if |gap| is below the current k-th distance.
            if len(best) < k or abs(gap) < best[-1][0]:
                visit(far)

        visit(self.root)
        return [item[1] for item in best]

    def plot(self, ax=None, min_x=None, max_x=None, min_y=None, max_y=None):
        """Draw a 2-D tree: its points plus the axis-aligned split lines.

        Any bound left as ``None`` defaults to the extent of the stored points.
        The visible area is the bounds padded by 1 on each side, and every split
        line is clipped to the cell of the partition it divides.

        Returns the Matplotlib axes drawn on (a new one is created when ``ax``
        is None).

        Raises:
            ValueError: if the stored points are not 2-D.
        """
        if ax is None:
            _, ax = plt.subplots()
        if self.root is None:
            return ax

        # Collect every stored point once (iteratively) to get the data extent.
        points = []
        stack = [self.root]
        while stack:
            node = stack.pop()
            if node is None:
                continue
            points.append(node.data)
            stack.append(node.left)
            stack.append(node.right)
        if len(points[0]) != 2:
            raise ValueError("KDTree.plot only supports 2-D points")
        xs = [p[0] for p in points]
        ys = [p[1] for p in points]

        min_x = min(xs) if min_x is None else min_x
        max_x = max(xs) if max_x is None else max_x
        min_y = min(ys) if min_y is None else min_y
        max_y = max(ys) if max_y is None else max_y
        x_view = [min_x - 1, max_x + 1]
        y_view = [min_y - 1, max_y + 1]

        def draw_splits(node, x_range, y_range):
            # Draw this node's split line inside its cell, then recurse into
            # the two sub-cells it creates.
            if node is None:
                return
            x, y = node.data
            if node.split == 0:
                ax.plot([x, x], y_range, color='gray', alpha=0.5)
                draw_splits(node.left, [x_range[0], x], y_range)
                draw_splits(node.right, [x, x_range[1]], y_range)
            else:
                ax.plot(x_range, [y, y], color='gray', alpha=0.5)
                draw_splits(node.left, x_range, [y_range[0], y])
                draw_splits(node.right, x_range, [y, y_range[1]])

        draw_splits(self.root, x_view, y_view)
        # Scatter after the lines so the points are drawn on top.
        ax.scatter(xs, ys, color='black')
        ax.set_xlim(x_view)
        ax.set_ylim(y_view)
        ax.set_aspect('equal')
        ax.tick_params(axis='both', which='both', length=0)
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.spines['bottom'].set_color('gray')
        ax.spines['left'].set_color('gray')
        return ax
# Example usage: build a tree from six 2-D points, print the point nearest to
# (4, 4), and draw the resulting space partition. Guarded so that importing
# this module does not print or open a blocking plot window.
if __name__ == "__main__":
    points = np.array([(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)])
    tree = KDTree(points)
    print(tree.search(np.array([4, 4])))
    tree.plot()
    plt.show()
```
注：本示例代码力求简洁，并非最优实现：KD 树构建时在每一层都对点集重新排序来选取中位数（整体约 O(n log² n)），搜索采用递归的分支限界（剪枝）方式，绘图功能仅支持二维点。