Guest User

Quickselect

a guest
May 28th, 2014
329
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. def select(xs, k):
  2.     """ Returns sorted(xs)[k] using expected linear time """
  3.     # Inv: xs[:a] <= res < xs[b:]
  4.     a, b = 0, len(xs)
  5.     while a + 1 != b:
  6.         x = choose_pivot(xs, a, b)
  7.         # Partition xs such that xs[:i] < x, xs[i:j] = x, xs[j:] > x
  8.         i, j = partition3(xs, a, b, x, x+1)
  9.         if k < i:
  10.             b = i
  11.         # If i <= res < j, res must be equal to the pivot
  12.         elif k < j:
  13.             a, b = i, i+1
  14.         else:
  15.             a = j
  16.     return xs[a]
  17.  
  18. def select2(xs, a, b, k):
  19.     """ Dual pivot, ala Java 7 quick sort, version of select """
  20.     if a + 1 == b:
  21.         return xs[a]
  22.     x, y = choose_pivots(xs, a, b)
  23.     i, j = partition3(xs, a, b, x, y)
  24.     if k < i:
  25.         return select2(xs, a, i, k)
  26.     if k >= j:
  27.         return select2(xs, j, b, k)
  28.     # If all the values in the middle segment are equal,
  29.     # we have to return early to ensure termination.
  30.     if x + 1 == y:
  31.         return x
  32.     return select2(xs, i, j, k)
  33.  
  34. def select3(xs, a, b, k):
  35.     """ Selects the kth and k+1th element and returns their average.
  36.        k must be less than len(xs) """
  37.     if a + 1 == b:
  38.         return xs[a]
  39.     # We use a single pivot here for simplicity
  40.     x = choose_pivot(xs, a, b)
  41.     i, j = partition3(xs, a, b, x, x+1)
  42.     if k+1 < i:
  43.         return select3(xs, a, i, k)
  44.     if k+1 == i:
  45.         # If we overlap the middle segment halfway from below
  46.         return (max(xs[a:i]) + x)/2.
  47.     if k+1 < j:
  48.         return x
  49.     if k+1 == j:
  50.         # If we overlap the middle segment halfway from above
  51.         return (x + min(xs[j:b]))/2.
  52.     return select3(xs, j, b, k)
  53.  
  54. def partition3(xs, a, b, x, y):
  55.     """ Post cond: xs[a:i] < x <= xs[i:j] < y <= xs[j:b]
  56.        Inv: xs[a:i] < x <= xs[i:j] < y <= xs[k:b] """
  57.     assert 0 <= a < b <= len(xs) and x < y
  58.     i, j, k = a, a, b
  59.     while j != k:
  60.         if xs[j] < x:
  61.             xs[i], xs[j] = xs[j], xs[i]
  62.             i, j = i+1, j+1
  63.         elif xs[j] < y:
  64.             j = j+1
  65.         else:
  66.             xs[j], xs[k-1] = xs[k-1], xs[j]
  67.             k = k-1
  68.     return i, j
  69.  
  70. def choose_pivot(xs, a, b):
  71.     """ Chooses a single pivot """
  72.     x = median_of_three(xs[a], xs[(a+b)//2], xs[b-1])
  73.     return x
  74.  
  75. def median_of_three(x, y, z):
  76.     """ Median of three for stability """
  77.     return x + y + z - (min(x,y,z) + max(x,y,z))
  78.  
  79. def choose_pivots(xs, a, b):
  80.     """ Chooses two pivots, similarly to how Java 7 does it """
  81.     third = (b-a)//3
  82.     x, y = xs[a+third], xs[b-1-third]
  83.     if x < y: return x, y
  84.     if x == y: return x, x+1
  85.     return y, x
RAW Paste Data