SHARE
TWEET

Untitled

a guest Sep 19th, 2019 78 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. import numpy as np
  2. from time import time
  3. from scipy.spatial import KDTree as kd
  4. from functools import reduce
  5. import matplotlib.pyplot as plt
  6.  
  7. def euclid(c, cs, r):
  8.     return ((cs[:,0] - c[0]) ** 2 + (cs[:,1] - c[1]) ** 2 + (cs[:,2] - c[2]) ** 2) < r ** 2
  9.  
  10. def find_nn_naive(cells, radius):
  11.     matrix = np.zeros((cells.shape[0] ** 2,2))
  12.     mi = 0
  13.     for i in range(len(cells)):
  14.         cell = cells[i]
  15.         cands = euclid(cell, cells, radius)
  16.         ids = np.where(cands)[0]
  17.         ids = ids[ids != i]
  18.         for j in range(len(ids)):
  19.             matrix[mi + j, 0] = i
  20.             matrix[mi + j, 1] = ids[j]
  21.         mi += len(ids)
  22.     return matrix[0:mi,:]
  23.  
  24. def find_nn_kd_seminaive(cells, radius):
  25.     tree = kd(cells)
  26.     matrix = np.zeros((cells.shape[0] ** 2,2))
  27.     mi = 0
  28.     for i in range(len(cells)):
  29.         res = tree.query_ball_point(cells[i], radius)
  30.         if len(res) > 1:
  31.             ids = res[1:]
  32.             for j in range(len(ids)):
  33.                 matrix[mi + j, 0] = i
  34.                 matrix[mi + j, 1] = ids[j]
  35.             mi += len(ids)
  36.     return matrix[0:mi,:]
  37.  
  38. def find_nn_kd_by_tree(cells, radius):
  39.     tree = kd(cells)
  40.     res = list(filter(lambda x: len(x) > 1, tree.query_ball_tree(tree, radius)))
  41.     matrix = np.zeros((reduce(lambda count, x: count + len(x) - 1, res, 0), 2))
  42.     mi = 0
  43.     for j in range(len(res)):
  44.         for i in range(1, len(res[j])):
  45.             matrix[mi, 0] = res[j][0]
  46.             matrix[mi, 1] = res[j][i]
  47.             mi += 1
  48.  
  49.     return matrix[0:mi,:]
  50.  
  51. min_iter = 5000
  52. max_iter = 10000
  53. step_iter = 1000
  54.  
  55. rng = range(min_iter, max_iter, step_iter)
  56. elapsed_naive = np.zeros(len(rng))
  57. elapsed_kd_sn = np.zeros(len(rng))
  58. elapsed_kd_tr = np.zeros(len(rng))
  59. shapes = np.zeros((len(rng), 3))
  60. ei = 0
  61. for i in rng:
  62.     random_cells = np.random.rand(i, 3) * 400.
  63.     t = time()
  64.     r1 = find_nn_naive(random_cells, 50.)
  65.     shapes[ei, 0] = r1.shape[0]
  66.     elapsed_naive[ei] = time() - t
  67.     # print('naive shape:', r1.shape)
  68.     # print('naive time:', elapsed_naive[ei])
  69.     t = time()
  70.     r2 = find_nn_kd_seminaive(random_cells, 50.)
  71.     shapes[ei, 1] = r2.shape[0]
  72.     elapsed_kd_sn[ei] = time() - t
  73.     # print('seminaive shape:', r2.shape)
  74.     # print('seminaive time:', elapsed_kd_sn[ei])
  75.     t = time()
  76.     r3 = find_nn_kd_by_tree(random_cells, 50.)
  77.     shapes[ei, 2] = r3.shape[0]
  78.     elapsed_kd_tr[ei] = time() - t
  79.     # print('tree shape:', r3.shape)
  80.     # print('tree time:', elapsed_kd_tr[ei])
  81.     # print(r1,r2,r3)
  82.     # exit()
  83.     ei += 1
  84.  
  85. # Plot result comparison: Do all 3 implementations yield the same result? -> 3 overlapping lines
  86. plt.plot(rng, shapes[:,0], label='naive')
  87. plt.plot(rng, shapes[:,1], label='semi kd')
  88. plt.plot(rng, shapes[:,2], label='full kd')
  89. plt.legend()
  90. plt.show(block=True)
  91.  
  92. # What's the runtime for each?
  93. plt.plot(rng, elapsed_naive, label='naive')
  94. plt.plot(rng, elapsed_kd_sn, label='semi kd')
  95. plt.plot(rng, elapsed_kd_tr, label='full kd')
  96. plt.legend()
  97. plt.show(block=True)
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top