• API
• FAQ
• Tools
• Archive
SHARE
TWEET Untitled a guest Sep 19th, 2019 78 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
from functools import reduce
from time import perf_counter, time

import matplotlib.pyplot as plt
import numpy as np
from scipy.spatial import KDTree as kd

def euclid(c, cs, r):
    """Return a mask of the points in ``cs`` lying strictly within ``r`` of ``c``.

    Parameters
    ----------
    c : array_like, shape (D,)
        Query point.
    cs : array_like, shape (N, D)
        Candidate points, one per row.
    r : float
        Search radius.

    Returns
    -------
    ndarray of bool, shape (N,)
        ``True`` where the squared Euclidean distance from ``cs[k]`` to ``c``
        is strictly less than ``r ** 2``.
    """
    # Subtract c row-wise (broadcast over the N rows). The previous version
    # subtracted the whole vector c from each coordinate column, which only
    # broadcast when N happened to equal 3 and even then was wrong.
    diff = np.asarray(cs) - np.asarray(c)
    return np.sum(diff ** 2, axis=1) < r ** 2

def find_nn_naive(cells, radius):
    """Find all ordered neighbour pairs by brute force.

    Parameters
    ----------
    cells : array_like, shape (N, D)
        Point coordinates, one point per row.
    radius : float
        Neighbourhood radius; pairs with distance strictly less than
        ``radius`` are reported.

    Returns
    -------
    ndarray of float, shape (M, 2)
        One row ``(i, j)`` per ordered pair with ``i != j``. Rows are grouped
        by ``i`` ascending, and by ``j`` ascending within each group. Indices
        are stored as floats, matching the other ``find_nn_*`` functions.
    """
    cells = np.asarray(cells)
    r_sq = radius ** 2
    # Collect pairs incrementally instead of preallocating N**2 rows: the old
    # ``cells.shape ** 2`` was a TypeError, and an N**2 x 2 float buffer
    # would be ~1.6 GB at N = 10000.
    firsts, seconds = [], []
    for i, cell in enumerate(cells):
        within = np.sum((cells - cell) ** 2, axis=1) < r_sq
        # np.flatnonzero yields a flat index array; np.where would return a
        # tuple, which cannot be boolean-filtered.
        others = np.flatnonzero(within)
        others = others[others != i]  # drop the self-pair
        firsts.extend([i] * len(others))
        seconds.extend(others.tolist())
    # column_stack of two empty lists yields shape (0, 2), so empty results
    # keep the documented shape.
    return np.column_stack((firsts, seconds)).astype(float)

def find_nn_kd_seminaive(cells, radius):
    """Find all ordered neighbour pairs with one KD-tree ball query per point.

    Parameters
    ----------
    cells : array_like, shape (N, D)
        Point coordinates, one point per row.
    radius : float
        Neighbourhood radius passed to ``KDTree.query_ball_point``.

    Returns
    -------
    ndarray of float, shape (M, 2)
        One row ``(i, j)`` per ordered pair with ``i != j``. Rows are grouped
        by ``i`` ascending, and by ``j`` ascending within each group.

    Notes
    -----
    The boundary convention of the scipy ball query may differ from the
    strict ``<`` used by ``find_nn_naive``. Points at exactly ``radius``
    could be classified differently (TODO confirm against scipy docs).
    """
    cells = np.asarray(cells)
    if len(cells) == 0:
        # Nothing to query; avoid building a tree over zero points.
        return np.empty((0, 2))
    tree = kd(cells)
    firsts, seconds = [], []
    for i, point in enumerate(cells):
        hits = tree.query_ball_point(point, radius)
        # The result is not guaranteed to start with the query point itself,
        # so the old ``res[1:]`` could keep (i, i) and drop a real neighbour.
        # Exclude self explicitly and sort for deterministic output.
        others = sorted(j for j in hits if j != i)
        firsts.extend([i] * len(others))
        seconds.extend(others)
    return np.column_stack((firsts, seconds)).astype(float)

def find_nn_kd_by_tree(cells, radius):
    """Find all ordered neighbour pairs with a single tree-vs-tree ball query.

    Parameters
    ----------
    cells : array_like, shape (N, D)
        Point coordinates, one point per row.
    radius : float
        Neighbourhood radius passed to ``KDTree.query_ball_tree``.

    Returns
    -------
    ndarray of float, shape (M, 2)
        One row ``(i, j)`` per ordered pair with ``i != j``. Rows are grouped
        by ``i`` ascending, and by ``j`` ascending within each group.
    """
    cells = np.asarray(cells)
    if len(cells) == 0:
        # Nothing to query; avoid building a tree over zero points.
        return np.empty((0, 2))
    tree = kd(cells)
    # neighbours[i] lists the indices within ``radius`` of cells[i]. Iterate
    # the unfiltered list so the position still identifies the query point.
    # The old ``filter`` discarded that index and then stored a whole list
    # into a scalar cell.
    neighbours = tree.query_ball_tree(tree, radius)
    firsts, seconds = [], []
    for i, hits in enumerate(neighbours):
        # Self is not guaranteed to come first; exclude it explicitly.
        others = sorted(j for j in hits if j != i)
        firsts.extend([i] * len(others))
        seconds.extend(others)
    return np.column_stack((firsts, seconds)).astype(float)

# Benchmark: compare result size and runtime of the three neighbour searches
# on uniformly random points in a cube.
min_iter = 5000
max_iter = 10000
step_iter = 1000
box_size = 400.      # side length of the cube the random points are drawn from
search_radius = 50.  # neighbourhood radius passed to every implementation

# Point counts to benchmark; ``range`` excludes ``max_iter`` itself.
rng = range(min_iter, max_iter, step_iter)
elapsed_naive = np.zeros(len(rng))
elapsed_kd_sn = np.zeros(len(rng))
elapsed_kd_tr = np.zeros(len(rng))
# shapes[ei, k] = number of neighbour pairs found by implementation k.
shapes = np.zeros((len(rng), 3))

labels = ('naive', 'semi kd', 'full kd')
benchmarks = (
    (find_nn_naive, elapsed_naive),
    (find_nn_kd_seminaive, elapsed_kd_sn),
    (find_nn_kd_by_tree, elapsed_kd_tr),
)

for ei, n_cells in enumerate(rng):
    random_cells = np.random.rand(n_cells, 3) * box_size
    for col, (find_nn, elapsed) in enumerate(benchmarks):
        # perf_counter is the monotonic high-resolution clock for timing;
        # only the search call is inside the timed region.
        t = perf_counter()
        pairs = find_nn(random_cells, search_radius)
        elapsed[ei] = perf_counter() - t
        # Store the row count; assigning the (M, 2) shape tuple to a scalar
        # cell raised a ValueError.
        shapes[ei, col] = pairs.shape[0]

# Result comparison: if all three implementations agree, the lines overlap.
plt.figure()
for col, label in enumerate(labels):
    plt.plot(rng, shapes[:, col], label=label)
plt.xlabel('number of cells')
plt.ylabel('number of neighbour pairs')
plt.legend()
plt.show(block=True)

# Runtime of each implementation.
plt.figure()
for label, (_, elapsed) in zip(labels, benchmarks):
    plt.plot(rng, elapsed, label=label)
plt.xlabel('number of cells')
plt.ylabel('runtime [s]')
plt.legend()
plt.show(block=True)
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy.

Top