daily pastebin goal
4%
SHARE
TWEET

Untitled

a guest Oct 16th, 2018 57 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. # experiment.py
  2. import numpy as np
  3. import time
  4. from scipy.spatial.distance import cdist
  5.  
  6.  
  7. def cdist_fast(XA, XB):
  8.     XA_norm = np.sum(XA**2, axis=1)
  9.     XB_norm = np.sum(XB**2, axis=1)
  10.     XA_XB_T = np.dot(XA, XB.T)
  11.     distances = XA_norm.reshape(-1,1) + XB_norm - 2*XA_XB_T
  12.     return distances
  13.  
  14.  
  15. def main():
  16.     M,N = 5000, 128
  17.     XA = np.random.randn(M,N)
  18.  
  19.     t = time.time()
  20.     distances_cdist = cdist(XA, XA, metric='sqeuclidean')
  21.     time_cdist = time.time() - t
  22.  
  23.     t = time.time()
  24.     distances_cdist_fast = cdist_fast(XA, XA)
  25.     time_cdist_fast = time.time() - t
  26.  
  27.     print(f'time_cdist = {time_cdist:.3f} s')
  28.     print(f'time_cdist_fast = {time_cdist_fast:.3f} s')
  29.  
  30.     # check validity of results
  31.     assert np.allclose(distances_cdist, distances_cdist_fast)
  32.  
  33.     # check that the results are non-negative
  34.     try:
  35.         assert (distances_cdist >= 0.0).all()
  36.     except AssertionError:
  37.         print('Numerical instability in cdist()')
  38.  
  39.     try:
  40.         assert (distances_cdist_fast >= 0.0).all()
  41.     except AssertionError:
  42.         print('Numerical instability in cdist_fast()')
  43.  
  44.  
  45. if __name__ == '__main__':
  46.     main()
  47.    
  48. $ python experiment.py
  49. time_cdist = 3.457 s
  50. time_cdist_fast = 0.625 s
  51. Numerical instability in cdist_fast()
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
 
Top