Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- # experiment.py
- import numpy as np
- import time
- from scipy.spatial.distance import cdist
- def cdist_fast(XA, XB):
- XA_norm = np.sum(XA**2, axis=1)
- XB_norm = np.sum(XB**2, axis=1)
- XA_XB_T = np.dot(XA, XB.T)
- distances = XA_norm.reshape(-1,1) + XB_norm - 2*XA_XB_T
- return distances
- def main():
- M,N = 5000, 128
- XA = np.random.randn(M,N)
- t = time.time()
- distances_cdist = cdist(XA, XA, metric='sqeuclidean')
- time_cdist = time.time() - t
- t = time.time()
- distances_cdist_fast = cdist_fast(XA, XA)
- time_cdist_fast = time.time() - t
- print(f'time_cdist = {time_cdist:.3f} s')
- print(f'time_cdist_fast = {time_cdist_fast:.3f} s')
- # check validity of results
- assert np.allclose(distances_cdist, distances_cdist_fast)
- # check that the results are non-negative
- try:
- assert (distances_cdist >= 0.0).all()
- except AssertionError:
- print('Numerical instability in cdist()')
- try:
- assert (distances_cdist_fast >= 0.0).all()
- except AssertionError:
- print('Numerical instability in cdist_fast()')
- if __name__ == '__main__':
- main()
- $ python experiment.py
- time_cdist = 3.457 s
- time_cdist_fast = 0.625 s
- Numerical instability in cdist_fast()
Add Comment
Please, Sign In to add comment