sweeneyde

slower with functions

Dec 19th, 2019
84
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 4.21 KB | None | 0 0
  1. import itertools
  2.  
  3. def multi_merge(*iterables, key=None, reverse=False):
  4.     """
  5.    Non-recursive mergesort-style algorithm:
  6.    Maintain a "tree of losers" of comparisons
  7.    as winners advance to the root.
  8.    (c.f. Knuth Volume 3, Chapter 5.4.1. on "Multiway Merging").
  9.  
  10.    Essentially use a heap, but instead of pushing new items to the
  11.    root and sifting down, push them directly to the leaf nodes--items
  12.    only ever move toward the root.
  13.  
  14.    Example for n=5 iterables A--E:
  15.            0
  16.        1       2
  17.      3   4   5   6
  18.     7 8  |   |   |
  19.     | |  C   D   E
  20.     A B
  21.        (shift = -3)
  22.    Always:
  23.        n leaf nodes
  24.            - get their values directly from iterators
  25.        n-1 internal nodes
  26.            - get their values from their children
  27.    """
  28.     n = len(iterables)
  29.     if n == 0:
  30.         return
  31.     if n == 1:
  32.         yield from iterables[0]
  33.         return
  34.  
  35.     tree = [object()] * (n + n - 1)
  36.     sentinel = object()
  37.  
  38.     if key is None:
  39.         def value(x): return x
  40.         if reverse:
  41.             def compare(x, y): return y < x
  42.         else:
  43.             def compare(x, y): return x < y
  44.     else:
  45.         def value(x): return x[1]
  46.         iterables = [zip(map(key, it), it) for it in iterables]
  47.         if reverse:
  48.             def compare(x, y): return y[0] < x[0]
  49.         else:
  50.             def compare(x, y): return x[0] < y[0]
  51.  
  52.     # For stability, rotate the list so that the first iterator
  53.     # moves from the smallest index to the leftmost index.
  54.     shift = n - (1 << n.bit_length())
  55.     getters = [itertools.chain(iterables[i + shift],
  56.                                itertools.repeat(sentinel)
  57.                                ).__next__
  58.                for i in range(len(iterables))]
  59.  
  60.     # initialize leaves
  61.     for i, getter in enumerate(getters):
  62.         tree[i + n - 1] = getter()
  63.  
  64.  
  65.     # general case with keys
  66.     for i in reversed(range(n - 1)):
  67.         while i < n - 1:
  68.             left = 2 * i + 1
  69.             right = left + 1
  70.             t_left = tree[left]
  71.             t_right = tree[right]
  72.             if t_right is sentinel:
  73.                 winner = left
  74.             elif t_left is sentinel:
  75.                 winner = right
  76.             elif compare(t_right, t_left):
  77.                 winner = right
  78.             else:
  79.                 winner = left
  80.             tree[i] = tree[winner]
  81.             i = winner
  82.         tree[i] = getters[i - (n - 1)]()
  83.  
  84.     # Main loop:
  85.     #   - yield the overall winner
  86.     #   - replace parents with their better children
  87.     #   - replace leaves using original iterables
  88.     while True:
  89.         t0 = tree[0]
  90.         if t0 is sentinel:
  91.             return
  92.         yield value(t0)
  93.         i = 0
  94.         while i < n - 1:
  95.             left = 2 * i + 1
  96.             right = left + 1
  97.             t_left = tree[left]
  98.             t_right = tree[right]
  99.             if t_right is sentinel:
  100.                 winner = left
  101.             elif t_left is sentinel:
  102.                 winner = right
  103.             elif compare(t_right, t_left):
  104.                 winner = right
  105.             else:
  106.                 winner = left
  107.             tree[i] = tree[winner]
  108.             i = winner
  109.         tree[i] = getters[i - (n - 1)]()
  110.  
  111.  
  112. if __name__ == "__main__":
  113.     from timeit import timeit
  114.     import random
  115.     from test import support
  116.     py_heapq = support.import_fresh_module('heapq', blocked=['_heapq'])
  117.  
  118.     print("number of iterables, (time to pure-python multi_merge / pure-python heapq merge)")
  119.  
  120.     n = 0
  121.     while True:
  122.         n = max(int(1.2*n), n+1)
  123.         t1 = t2 = 0
  124.         for _ in range(3):
  125.             # lists = [
  126.             #     range(i, 100, 10)
  127.             #     for i in range(5)
  128.             # ]
  129.             lists = [
  130.                 sorted(random.random() for _ in
  131.                        range(random.randint(200, 1000)))
  132.                 for _ in range(n)
  133.             ]
  134.             l1 = list(multi_merge(*lists, key=abs))
  135.             l2 = list(py_heapq.merge(*lists, key=abs))
  136.             assert l1 == l2, l1
  137.             t1 += timeit(lambda: list(multi_merge(*lists)), number=50)
  138.             t2 += timeit(lambda: list(py_heapq.merge(*lists)), number=50)
  139.         print(n, t1/t2, sep='\t')
Add Comment
Please, Sign In to add comment