Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
def select(xs, k):
    """Return sorted(xs)[k] without fully sorting xs.

    Quickselect with a median-of-three pivot (choose_pivot) and a three-way
    partition (partition3).  Expected linear time on randomly ordered input;
    the pivot rule is deterministic, so crafted inputs can still force
    quadratic time.  xs is reordered in place as a side effect.

    As with list indexing, a negative k counts from the end.

    Raises:
        IndexError: if k is out of range (this includes an empty xs).

    Elements are assumed to be integers: the partition uses x + 1 as an
    exclusive upper bound, so the middle segment holds only copies of x.
    """
    n = len(xs)
    if k < 0:
        k += n
    # Without this check an out-of-range k silently returned a wrong element.
    if not 0 <= k < n:
        raise IndexError("select index out of range")
    # Inv: xs[:a] <= res <= xs[b:], where res = sorted(xs)[k] lies in xs[a:b]
    a, b = 0, n
    while a + 1 != b:
        x = choose_pivot(xs, a, b)
        # Partition xs[a:b] such that xs[a:i] < x, xs[i:j] == x, xs[j:b] > x
        i, j = partition3(xs, a, b, x, x + 1)
        if k < i:
            b = i
        elif k < j:
            # i <= k < j: res must be equal to the pivot, which sits at xs[i]
            a, b = i, i + 1
        else:
            a = j
    return xs[a]
def select2(xs, a, b, k):
    """Dual-pivot quickselect, in the spirit of Java 7's dual-pivot quicksort.

    Returns the element that would occupy index k if xs[a:b] were sorted in
    place.  k is an absolute index into xs and must satisfy a <= k < b.
    xs is reordered in place as a side effect.

    Each round picks two pivots x < y (choose_pivots) and partitions the
    range into xs[a:i] < x <= xs[i:j] < y <= xs[j:b] (partition3), then
    narrows to the segment containing index k.

    Raises:
        IndexError: if k is not within [a, b).

    Elements are assumed to be integers (see the early return below).
    """
    # Without this check an out-of-range k silently returned a wrong element.
    if not a <= k < b:
        raise IndexError("select2 index out of range")
    # Iterative rather than recursive, so that a long sequence of poor pivots
    # cannot exhaust Python's recursion limit.
    while a + 1 != b:
        x, y = choose_pivots(xs, a, b)
        i, j = partition3(xs, a, b, x, y)
        if k < i:
            b = i
        elif k >= j:
            a = j
        elif x + 1 == y:
            # The middle band [x, x + 1) holds only copies of x (integers).
            # y need not occur in xs, so this band may span the whole range;
            # narrowing to it would make no progress, so we return early to
            # ensure termination.
            return x
        else:
            a, b = i, j
    return xs[a]
def select3(xs, a, b, k):
    """Return the average of the k-th and (k+1)-th smallest elements of xs[a:b].

    k is an absolute index into xs: the result is the mean of the values
    that would occupy xs[k] and xs[k + 1] if xs[a:b] were sorted in place.
    xs is reordered in place as a side effect.

    Requires a <= k < b.  When k is the last index of the range (k == b - 1)
    there is no (k+1)-th element, and the k-th smallest element itself is
    returned; this includes the single-element range.

    Raises:
        IndexError: if k is not within [a, b).

    Elements are assumed to be integers: the partition uses x + 1 as an
    exclusive upper bound, so the middle segment holds only copies of x.
    """
    if not a <= k < b:
        raise IndexError("select3 index out of range")
    if a + 1 == b:
        return xs[a]
    # We use a single pivot here for simplicity
    x = choose_pivot(xs, a, b)
    i, j = partition3(xs, a, b, x, x + 1)
    # Now xs[a:i] < x, xs[i:j] == x, xs[j:b] > x.
    if k + 1 < i:
        # Both target positions lie in the lower segment.
        return select3(xs, a, i, k)
    if k + 1 == i:
        # We overlap the middle segment halfway from below:
        # position k is the largest value below x, position k+1 is x.
        return (max(xs[a:i]) + x) / 2.0
    if k + 1 < j:
        # Both target positions lie in the run of values equal to x.
        return x
    if k + 1 == j:
        if j == b:
            # k is the last index, so there is no (k+1)-th element; return
            # the k-th one (x).  Without this guard min() below would be
            # called on an empty slice and raise ValueError.
            return x
        # We overlap the middle segment halfway from above:
        # position k is x, position k+1 is the smallest value above x.
        return (x + min(xs[j:b])) / 2.0
    # Both target positions lie in the upper segment.
    return select3(xs, j, b, k)
def partition3(xs, a, b, x, y):
    """Three-way partition xs[a:b] in place around the half-open band [x, y).

    Dutch-national-flag scheme: one left-to-right scan that swaps small
    values to the front of the range and large values to the back.

    Returns (i, j) such that afterwards
        xs[a:i] < x <= xs[i:j] < y <= xs[j:b]
    Preconditions (checked with assert): 0 <= a < b <= len(xs) and x < y.
    """
    assert 0 <= a < b <= len(xs) and x < y
    lo = mid = a
    hi = b
    # Invariant: xs[a:lo] < x <= xs[lo:mid] < y <= xs[hi:b];
    # xs[mid:hi] has not been classified yet.
    while mid != hi:
        value = xs[mid]
        if value < x:
            # Belongs in the low segment: swap it to the boundary at lo.
            xs[lo], xs[mid] = xs[mid], xs[lo]
            lo += 1
            mid += 1
        elif value < y:
            # Already in the band [x, y); just extend the middle segment.
            mid += 1
        else:
            # Belongs in the high segment: swap with the last unclassified
            # slot and shrink the unclassified region from the right.
            hi -= 1
            xs[mid], xs[hi] = xs[hi], xs[mid]
    return lo, mid
def choose_pivot(xs, a, b):
    """Pick one pivot for partitioning xs[a:b].

    Samples the first, middle and last elements of the range and returns
    median_of_three of those samples.  The range should be non-empty
    (a < b); otherwise the samples fall outside it.
    """
    first, middle, last = xs[a], xs[(a + b) // 2], xs[b - 1]
    return median_of_three(first, middle, last)
def median_of_three(x, y, z):
    """Return the median of three values.

    Used for pivot selection: the median of three sampled elements is a
    more robust pivot than any single fixed position.

    The median is found by comparison alone, so the result is always one of
    the three arguments and any mutually comparable values work.  An
    arithmetic formula such as x + y + z - (min + max) loses precision with
    floats (it gives 0.0 for (1e20, 1.0, -1e20), a value not among the
    inputs) and fails outright for non-numeric values such as strings.
    """
    return sorted((x, y, z))[1]
def choose_pivots(xs, a, b):
    """Pick two pivots p < q from xs[a:b] for the dual-pivot select.

    Samples the elements one third of the way in from each end of the range
    (loosely modelled on Java 7's dual-pivot quicksort) and returns them in
    ascending order.  When both samples are equal, (v, v + 1) is returned so
    that p < q still holds; with integer elements the band [p, q) then
    contains only copies of v.
    """
    offset = (b - a) // 3
    left = xs[a + offset]
    right = xs[b - 1 - offset]
    if left == right:
        return left, left + 1
    return (left, right) if left < right else (right, left)
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement