Advertisement
Guest User

Untitled

a guest
Jul 5th, 2015
187
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.27 KB | None | 0 0
  1. import numpy as np
  2.  
  3. import numba as nb
  4.  
  5. @nb.njit()
  6. def partition(values, idxs, left, right):
  7. """
  8. Partition method
  9. """
  10.  
  11. piv = values[idxs[left]]
  12. i = left + 1
  13. j = right
  14.  
  15. while True:
  16. while i <= j and values[idxs[i]] <= piv:
  17. i += 1
  18. while j >= i and values[idxs[j]] >= piv:
  19. j -= 1
  20. if j <= i:
  21. break
  22.  
  23. idxs[i], idxs[j] = idxs[j], idxs[i]
  24.  
  25. idxs[left], idxs[j] = idxs[j], idxs[left]
  26.  
  27. return j
  28.  
  29.  
  30. @nb.njit()
  31. def argsort1D(values):
  32.  
  33. idxs = np.arange(values.shape[0])
  34.  
  35. left = 0
  36. right = values.shape[0] - 1
  37.  
  38. max_depth = np.int(right / 2)
  39.  
  40. ndx = 0
  41.  
  42. tmp = np.zeros((max_depth, 2), dtype=np.int64)
  43.  
  44. tmp[ndx, 0] = left
  45. tmp[ndx, 1] = right
  46.  
  47. ndx = 1
  48. while ndx > 0:
  49.  
  50. ndx -= 1
  51. right = tmp[ndx, 1]
  52. left = tmp[ndx, 0]
  53.  
  54. piv = partition(values, idxs, left, right)
  55.  
  56. if piv - 1 > left:
  57. tmp[ndx, 0] = left
  58. tmp[ndx, 1] = piv - 1
  59. ndx += 1
  60.  
  61. if piv + 1 < right:
  62. tmp[ndx, 0] = piv + 1
  63. tmp[ndx, 1] = right
  64. ndx += 1
  65.  
  66. return idxs
  67.  
  68.  
  69. if __name__ == '__main__':
  70. x = np.random.random((100000,))
  71.  
  72. res = np.argsort(x)
  73. out = argsort1D(x)
  74.  
  75. assert np.all(res == out)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement