 # Quickselect

# Originally shared on Pastebin by a guest, May 28th, 2014.
def select(xs, k):
    """Return ``sorted(xs)[k]`` in expected linear time.

    ``xs`` is a list that is reordered in place.  Any totally ordered values
    work (ints, floats, strings, ...).  As with list indexing, a negative
    ``k`` counts from the end.

    Raises:
        IndexError: if ``k`` is out of range for ``xs`` (including empty xs).
    """
    n = len(xs)
    if k < 0:
        k += n
    if not 0 <= k < n:
        raise IndexError("select index out of range")
    # Inv: xs[:a] <= result <= xs[b:]  and  a <= k < b
    a, b = 0, n
    while b - a > 1:
        # Median of three by comparison, so the pivot is always an actual
        # element of xs[a:b].  That keeps the "equal" segment non-empty,
        # which guarantees the window shrinks on every iteration.
        x = sorted((xs[a], xs[(a + b) // 2], xs[b - 1]))[1]
        i, j = _partition_around(xs, a, b, x)
        if k < i:
            b = i
        elif k < j:
            # Every element of xs[i:j] equals the pivot.
            return xs[k]
        else:
            a = j
    return xs[a]


def _partition_around(xs, a, b, pivot):
    """Three-way partition xs[a:b] in place around ``pivot``.

    Returns (i, j) with xs[a:i] < pivot, xs[i:j] == pivot, xs[j:b] > pivot.
    Equality is decided by comparisons only, so no arithmetic on the values
    is required.
    """
    i, j, hi = a, a, b
    while j < hi:
        v = xs[j]
        if v < pivot:
            xs[i], xs[j] = v, xs[i]
            i += 1
            j += 1
        elif pivot < v:
            hi -= 1
            xs[j], xs[hi] = xs[hi], v
        else:
            j += 1
    return i, j

def select2(xs, a, b, k):
    """Dual-pivot quickselect, in the spirit of Java 7's dual-pivot quicksort.

    Returns the k-th smallest element of xs[a:b], where k is an absolute
    index into xs.  xs[a:b] is permuted in place.

    Relies on integer data: choose_pivots may return (v, v + 1), and a pivot
    band [x, x + 1) is assumed to contain only copies of x.
    """
    lo, hi = a, b
    while lo + 1 != hi:
        p, q = choose_pivots(xs, lo, hi)
        left_end, right_start = partition3(xs, lo, hi, p, q)
        if k < left_end:
            # The answer lies strictly below p.
            hi = left_end
        elif k >= right_start:
            # The answer is at or above q.
            lo = right_start
        elif p + 1 == q:
            # For integers the band [p, p + 1) holds only copies of p, so p
            # is the answer.  Returning here is also what ends the search
            # when every remaining value equals the pivot.
            return p
        else:
            # The answer lies in [p, q).
            lo, hi = left_end, right_start
    return xs[lo]

def select3(xs, a, b, k):
    """Return the average of the k-th and (k+1)-th smallest values of xs[a:b].

    ``k`` is an absolute index into ``xs`` with ``a <= k < b``; xs[a:b] is
    reordered in place.  For example, the median of an even-length list of
    length n is ``select3(xs, 0, n, n // 2 - 1)``.

    When ``k`` is the last index of the range there is no (k+1)-th value, and
    the k-th value itself is returned.

    NOTE(review): the pivot band is [x, x + 1), which holds only copies of x
    for integer data; for non-integer values it can mix distinct values.
    """
    if a + 1 == b:
        return xs[a]
    # A single pivot keeps this simple.
    x = choose_pivot(xs, a, b)
    # Post: xs[a:i] < x <= xs[i:j] < x + 1 <= xs[j:b]
    i, j = partition3(xs, a, b, x, x + 1)
    if k + 1 < i:
        # Both wanted positions lie in the lower segment.
        return select3(xs, a, i, k)
    if k + 1 == i:
        # We straddle the pivot band from below: the k-th value is the largest
        # of the lower segment, and the (k+1)-th value is the pivot.
        return (max(xs[a:i]) + x) / 2.
    if k + 1 < j:
        # Both positions fall inside the pivot band.
        return x
    if k + 1 == j:
        if j == b:
            # k is the last index: there is no (k+1)-th value to average.
            # Without this guard, min() below would be applied to an empty slice.
            return x
        # We straddle the pivot band from above: the k-th value is the pivot,
        # and the (k+1)-th value is the smallest of the upper segment.
        return (x + min(xs[j:b])) / 2.
    return select3(xs, j, b, k)

def partition3(xs, a, b, x, y):
    """Three-way partition xs[a:b] in place around the half-open band [x, y).

    Returns (i, j) such that  xs[a:i] < x <= xs[i:j] < y <= xs[j:b].
    Requires 0 <= a < b <= len(xs) and x < y (checked with ``assert``).
    """
    assert 0 <= a < b <= len(xs) and x < y
    lt, cur, gt = a, a, b
    # Inv: xs[a:lt] < x <= xs[lt:cur] < y <= xs[gt:b]; xs[cur:gt] unexamined.
    while cur < gt:
        item = xs[cur]
        if item < x:
            xs[lt], xs[cur] = item, xs[lt]
            lt += 1
            cur += 1
        elif item < y:
            cur += 1
        else:
            gt -= 1
            xs[cur], xs[gt] = xs[gt], item
    return lt, cur

def choose_pivot(xs, a, b):
    """Pick a single pivot for xs[a:b].

    Returns median_of_three of the first, middle and last elements of the range.
    """
    first, middle, last = xs[a], xs[(a + b) // 2], xs[b - 1]
    return median_of_three(first, middle, last)

def median_of_three(x, y, z):
    """Return the median of three values, using comparisons only.

    Pivot choice uses this to reduce the chance of badly unbalanced
    partitions.  Because no arithmetic is done on the values, the result is
    always one of the three arguments.  The former ``x + y + z - (min + max)``
    could return a value that is not among the inputs: for example
    (1e16, 0.5, 0.25) gave 0.0, and (inf, -inf, 0) gave nan.  Comparisons also
    work for any totally ordered type, such as strings.
    """
    return sorted((x, y, z))[1]

def choose_pivots(xs, a, b):
    """Choose two pivots (lo, hi) with lo < hi for dual-pivot partitioning.

    Loosely modeled on Java 7: the elements about one third and two thirds of
    the way through xs[a:b] are sampled and returned in ascending order.  When
    the two samples are equal, (v, v + 1) is returned, which assumes the values
    support ``+ 1`` (e.g. integers).
    """
    offset = (b - a) // 3
    lo, hi = xs[a + offset], xs[b - 1 - offset]
    if lo < hi:
        return lo, hi
    return (lo, lo + 1) if lo == hi else (hi, lo)