Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
def distance_squared(a, b):
    """Return the squared Euclidean distance between ``a`` and ``b``.

    Both arguments are converted to float arrays before subtracting, so
    plain sequences are accepted and unsigned-integer inputs cannot wrap
    around. The result is the sum of squared component differences,
    computed directly rather than by squaring a square root, which keeps
    exactly representable results exact (e.g. 3.0 instead of
    2.9999999999999996).

    Args:
        a: Array-like point.
        b: Array-like point broadcast-compatible with ``a``.

    Returns:
        The squared distance as a NumPy float scalar.
    """
    # ravel() keeps the "sum over every component" meaning for inputs of
    # any dimensionality, matching what the norm-based version computed.
    diff = (np.asarray(a, dtype=float) - np.asarray(b, dtype=float)).ravel()
    return np.dot(diff, diff)
def closer_distance(p, p1, p2):
    """Pick whichever of ``p1`` / ``p2`` lies nearer to ``p``.

    A candidate that is ``None`` is treated as missing: the other
    candidate is returned unchanged (which is ``None`` again when both
    are missing). When the two candidates are equally far from ``p``,
    ``p2`` is returned.

    Args:
        p: Reference point.
        p1: First candidate point, or ``None``.
        p2: Second candidate point, or ``None``.

    Returns:
        The nearer candidate (the point itself, not a distance).
    """
    if p1 is None or p2 is None:
        # At most one real candidate is available; hand back the other slot.
        return p2 if p1 is None else p1

    first_is_nearer = distance_squared(p, p1) < distance_squared(p, p2)
    return p1 if first_is_nearer else p2
class KDTree:
    """k-d tree over the rows of a 2-D array, with nearest-point search.

    Every node is a three-element list ``[left, right, payload]``:

    * internal node: ``left`` and ``right`` are child nodes and
      ``payload`` is the single point the node was split at;
    * leaf: ``left`` and ``right`` are both ``None`` and ``payload`` is a
      2-D array holding the leaf's points (it may be empty).

    The splitting axis is not stored in the node. It is implied by depth:
    the root splits on axis 0 and each level below uses the next axis,
    cycling modulo the number of dimensions.
    """

    def _make_kd_tree(self, points, dim, i=0, leaf_size=40):
        """Recursively build the subtree for ``points``.

        Args:
            points: 2-D array whose rows are the points of this subtree.
            dim: Number of coordinates per point.
            i: Axis used to split at this level.
            leaf_size: Largest number of points kept together in a leaf.

        Returns:
            The node list described in the class docstring.
        """
        # An empty slice is always a leaf, even for leaf_size <= 0, so the
        # median lookup below never indexes into an empty array.
        if len(points) <= max(leaf_size, 0):
            return [None, None, points]

        points = points[points[:, i].argsort()]
        half = len(points) // 2
        next_axis = (i + 1) % dim
        # leaf_size must be forwarded explicitly; otherwise every level
        # below the root would silently fall back to the default of 40.
        return [
            self._make_kd_tree(points[:half], dim, next_axis, leaf_size),
            self._make_kd_tree(points[half + 1:], dim, next_axis, leaf_size),
            points[half],
        ]

    def _closest_point(self, cur_node, pivot, dim, i=0):
        """Return the stored point nearest to ``pivot`` within ``cur_node``.

        Args:
            cur_node: Node list to search, or ``None``.
            pivot: Query point.
            dim: Number of coordinates per point.
            i: Splitting axis of ``cur_node``.

        Returns:
            The nearest point, or ``None`` when the subtree holds no points.
            Ties inside a leaf go to the earliest row.
        """
        if cur_node is None:
            return None

        left, right, payload = cur_node

        if left is None and right is None:
            # Leaf: payload is a block of points. An empty block can occur
            # (empty input, or the empty right half of a two-point split).
            if len(payload) == 0:
                return None
            offsets = np.asarray(payload, dtype=float) - np.asarray(pivot, dtype=float)
            # argmin returns the first minimum, i.e. the earliest row wins ties.
            nearest = np.argmin((offsets * offsets).sum(axis=1))
            return payload[nearest]

        next_axis = (i + 1) % dim
        if pivot[i] < payload[i]:
            near_side, far_side = left, right
        else:
            near_side, far_side = right, left

        best = closer_distance(
            pivot,
            self._closest_point(near_side, pivot, dim=dim, i=next_axis),
            payload,
        )
        # The far side can only hold a closer point if the splitting plane
        # itself is closer to the pivot than the best match found so far.
        plane_gap = pivot[i] - payload[i]
        if distance_squared(pivot, best) > plane_gap ** 2:
            best = closer_distance(
                pivot,
                self._closest_point(far_side, pivot, dim=dim, i=next_axis),
                best,
            )
        return best

    def __init__(self, X, leaf_size=40):
        """Build the tree from the rows of ``X``.

        Args:
            X: 2-D array-like of shape ``(n_points, n_dims)``.
            leaf_size: Largest number of points kept together in a leaf.
        """
        X = np.asarray(X)
        self.root = self._make_kd_tree(X, X.shape[1], i=0, leaf_size=leaf_size)

    def query(self, X, k=1):
        """Placeholder: not implemented, always returns ``None``.

        NOTE(review): presumably intended as a k-nearest-neighbour lookup
        for the rows of ``X`` -- confirm the desired return format before
        implementing.
        """
        pass
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement