Guest User

Untitled

a guest
Sep 19th, 2019
233
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.00 KB | None | 0 0
  1. import numpy as np
  2. from time import time
  3. from scipy.spatial import KDTree as kd
  4. from functools import reduce
  5. import matplotlib.pyplot as plt
  6.  
  7. def euclid(c, cs, r):
  8.     return ((cs[:,0] - c[0]) ** 2 + (cs[:,1] - c[1]) ** 2 + (cs[:,2] - c[2]) ** 2) < r ** 2
  9.  
  10. def find_nn_naive(cells, radius):
  11.     matrix = np.zeros((cells.shape[0] ** 2,2))
  12.     mi = 0
  13.     for i in range(len(cells)):
  14.         cell = cells[i]
  15.         cands = euclid(cell, cells, radius)
  16.         ids = np.where(cands)[0]
  17.         ids = ids[ids != i]
  18.         for j in range(len(ids)):
  19.             matrix[mi + j, 0] = i
  20.             matrix[mi + j, 1] = ids[j]
  21.         mi += len(ids)
  22.     return matrix[0:mi,:]
  23.  
  24. def find_nn_kd_seminaive(cells, radius):
  25.     tree = kd(cells)
  26.     matrix = np.zeros((cells.shape[0] ** 2,2))
  27.     mi = 0
  28.     for i in range(len(cells)):
  29.         res = tree.query_ball_point(cells[i], radius)
  30.         if len(res) > 1:
  31.             ids = res[1:]
  32.             for j in range(len(ids)):
  33.                 matrix[mi + j, 0] = i
  34.                 matrix[mi + j, 1] = ids[j]
  35.             mi += len(ids)
  36.     return matrix[0:mi,:]
  37.  
  38. def find_nn_kd_by_tree(cells, radius):
  39.     tree = kd(cells)
  40.     res = list(filter(lambda x: len(x) > 1, tree.query_ball_tree(tree, radius)))
  41.     matrix = np.zeros((reduce(lambda count, x: count + len(x) - 1, res, 0), 2))
  42.     mi = 0
  43.     for j in range(len(res)):
  44.         for i in range(1, len(res[j])):
  45.             matrix[mi, 0] = res[j][0]
  46.             matrix[mi, 1] = res[j][i]
  47.             mi += 1
  48.  
  49.     return matrix[0:mi,:]
  50.  
  51. min_iter = 5000
  52. max_iter = 10000
  53. step_iter = 1000
  54.  
  55. rng = range(min_iter, max_iter, step_iter)
  56. elapsed_naive = np.zeros(len(rng))
  57. elapsed_kd_sn = np.zeros(len(rng))
  58. elapsed_kd_tr = np.zeros(len(rng))
  59. shapes = np.zeros((len(rng), 3))
  60. ei = 0
  61. for i in rng:
  62.     random_cells = np.random.rand(i, 3) * 400.
  63.     t = time()
  64.     r1 = find_nn_naive(random_cells, 50.)
  65.     shapes[ei, 0] = r1.shape[0]
  66.     elapsed_naive[ei] = time() - t
  67.     # print('naive shape:', r1.shape)
  68.     # print('naive time:', elapsed_naive[ei])
  69.     t = time()
  70.     r2 = find_nn_kd_seminaive(random_cells, 50.)
  71.     shapes[ei, 1] = r2.shape[0]
  72.     elapsed_kd_sn[ei] = time() - t
  73.     # print('seminaive shape:', r2.shape)
  74.     # print('seminaive time:', elapsed_kd_sn[ei])
  75.     t = time()
  76.     r3 = find_nn_kd_by_tree(random_cells, 50.)
  77.     shapes[ei, 2] = r3.shape[0]
  78.     elapsed_kd_tr[ei] = time() - t
  79.     # print('tree shape:', r3.shape)
  80.     # print('tree time:', elapsed_kd_tr[ei])
  81.     # print(r1,r2,r3)
  82.     # exit()
  83.     ei += 1
  84.  
  85. # Plot result comparison: Do all 3 implementations yield the same result? -> 3 overlapping lines
  86. plt.plot(rng, shapes[:,0], label='naive')
  87. plt.plot(rng, shapes[:,1], label='semi kd')
  88. plt.plot(rng, shapes[:,2], label='full kd')
  89. plt.legend()
  90. plt.show(block=True)
  91.  
  92. # What's the runtime for each?
  93. plt.plot(rng, elapsed_naive, label='naive')
  94. plt.plot(rng, elapsed_kd_sn, label='semi kd')
  95. plt.plot(rng, elapsed_kd_tr, label='full kd')
  96. plt.legend()
  97. plt.show(block=True)
Add Comment
Please, Sign In to add comment