Advertisement
Guest User

Untitled

a guest
Nov 18th, 2019
105
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.04 KB | None | 0 0
  1. def distance_squared(a, b):
  2.     return np.linalg.norm(a - b)**2
  3.  
  4. def closer_distance(p, p1, p2):
  5.     if p1 is None:
  6.         return p2
  7.     if p2 is None:
  8.         return p1
  9.    
  10.     if distance_squared(p, p1) < distance_squared(p, p2):
  11.         return p1
  12.     else:
  13.         return p2
  14.    
  15.  
  16.  
  17. class KDTree:
  18.     def _make_kd_tree(self, points, dim, i=0, leaf_size=40):
  19.         if len(points) > leaf_size:
  20.             points = points[points[:,i].argsort()]
  21.             i = (i + 1) % dim
  22.             half = len(points) // 2
  23.             return [
  24.                 self._make_kd_tree(points[: half], dim, i),
  25.                 self._make_kd_tree(points[half + 1:], dim, i),
  26.                 points[half]
  27.             ]
  28.         else:
  29.             return [None, None, points]
  30.        
  31.     def _closest_point(self, cur_node, pivot, dim, i=0):
  32.         if cur_node is None:
  33.             return None
  34.        
  35.         if cur_node[0] is None and cur_node[1] is None:
  36.             points = cur_node[2]
  37.             best = points[0]
  38.             for cur in points:
  39.                 if distance_squared(pivot, best) > distance_squared(pivot, cur):
  40.                     best = cur
  41.                    
  42.             return best
  43.        
  44.         next_branch = None
  45.         opposite_branch = None
  46.        
  47.         if pivot[i] < cur_node[2][i]:
  48.             next_branch = cur_node[0]
  49.             opposite_branch = cur_node[1]
  50.         else:
  51.             next_branch = cur_node[1]
  52.             opposite_branch = cur_node[0]
  53.            
  54.         best = closer_distance(pivot, self._closest_point(next_branch, pivot, dim=dim, i=(i + 1) % dim), cur_node[2])
  55.        
  56.         if distance_squared(pivot, best) > (pivot[i] - cur_node[2][i]) ** 2:
  57.             best = closer_distance(pivot, self._closest_point(opposite_branch, pivot, dim=dim, i=(i + 1) % dim), best)
  58.            
  59.         return best
  60.        
  61.     def __init__(self, X, leaf_size=40):
  62.         self.root = self._make_kd_tree(X, X.shape[1], i=0, leaf_size=leaf_size)
  63.    
  64.     def query(self, X, k=1):
  65.         pass
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement