• API
• FAQ
• Tools
• Archive
SHARE
TWEET Untitled a guest Oct 16th, 2018 57 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
1. # experiment.py
2. import numpy as np
3. import time
4. from scipy.spatial.distance import cdist
5.
6.
7. def cdist_fast(XA, XB):
8.     XA_norm = np.sum(XA**2, axis=1)
9.     XB_norm = np.sum(XB**2, axis=1)
10.     XA_XB_T = np.dot(XA, XB.T)
11.     distances = XA_norm.reshape(-1,1) + XB_norm - 2*XA_XB_T
12.     return distances
13.
14.
15. def main():
16.     M,N = 5000, 128
17.     XA = np.random.randn(M,N)
18.
19.     t = time.time()
20.     distances_cdist = cdist(XA, XA, metric='sqeuclidean')
21.     time_cdist = time.time() - t
22.
23.     t = time.time()
24.     distances_cdist_fast = cdist_fast(XA, XA)
25.     time_cdist_fast = time.time() - t
26.
27.     print(f'time_cdist = {time_cdist:.3f} s')
28.     print(f'time_cdist_fast = {time_cdist_fast:.3f} s')
29.
30.     # check validity of results
31.     assert np.allclose(distances_cdist, distances_cdist_fast)
32.
33.     # check that the results are non-negative
34.     try:
35.         assert (distances_cdist >= 0.0).all()
36.     except AssertionError:
37.         print('Numerical instability in cdist()')
38.
39.     try:
40.         assert (distances_cdist_fast >= 0.0).all()
41.     except AssertionError:
42.         print('Numerical instability in cdist_fast()')
43.
44.
45. if __name__ == '__main__':
46.     main()
47.
48. \$ python experiment.py
49. time_cdist = 3.457 s
50. time_cdist_fast = 0.625 s
51. Numerical instability in cdist_fast()
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy.

Top