Guest User

Untitled

a guest
Oct 16th, 2018
88
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.23 KB | None | 0 0
  1. # experiment.py
  2. import numpy as np
  3. import time
  4. from scipy.spatial.distance import cdist
  5.  
  6.  
  7. def cdist_fast(XA, XB):
  8. XA_norm = np.sum(XA**2, axis=1)
  9. XB_norm = np.sum(XB**2, axis=1)
  10. XA_XB_T = np.dot(XA, XB.T)
  11. distances = XA_norm.reshape(-1,1) + XB_norm - 2*XA_XB_T
  12. return distances
  13.  
  14.  
  15. def main():
  16. M,N = 5000, 128
  17. XA = np.random.randn(M,N)
  18.  
  19. t = time.time()
  20. distances_cdist = cdist(XA, XA, metric='sqeuclidean')
  21. time_cdist = time.time() - t
  22.  
  23. t = time.time()
  24. distances_cdist_fast = cdist_fast(XA, XA)
  25. time_cdist_fast = time.time() - t
  26.  
  27. print(f'time_cdist = {time_cdist:.3f} s')
  28. print(f'time_cdist_fast = {time_cdist_fast:.3f} s')
  29.  
  30. # check validity of results
  31. assert np.allclose(distances_cdist, distances_cdist_fast)
  32.  
  33. # check that the results are non-negative
  34. try:
  35. assert (distances_cdist >= 0.0).all()
  36. except AssertionError:
  37. print('Numerical instability in cdist()')
  38.  
  39. try:
  40. assert (distances_cdist_fast >= 0.0).all()
  41. except AssertionError:
  42. print('Numerical instability in cdist_fast()')
  43.  
  44.  
  45. if __name__ == '__main__':
  46. main()
  47.  
  48. $ python experiment.py
  49. time_cdist = 3.457 s
  50. time_cdist_fast = 0.625 s
  51. Numerical instability in cdist_fast()
Add Comment
Please, Sign In to add comment