# Union/Intersection

ploffie Nov 16th, 2012 57 Never
1. import heapq
2.
def inter1(a, b):
    """Loop over a and b and check if there are doubles"""
    # Brute-force pairwise comparison: every element of a is compared with
    # every element of b, and one entry is emitted per matching pair, in the
    # order the elements appear in a.
    matches = []
    for left in a:
        for right in b:
            if left == right:
                matches.append(left)
    return matches
6.
def union1(a, b):
    """Loop over the elements of a+b"""
    # Walk the concatenation once, keeping only the first occurrence of each
    # value.  Membership is tested against the output list itself.
    unique = []
    combined = a + b
    for item in combined:
        if item in unique:
            continue
        unique.append(item)
    return unique
14.
def inter2(a, b):
    """Loop over a and check if present in b"""
    # Keep the elements of a, in order, that also occur somewhere in b.
    shared = []
    for candidate in a:
        if candidate in b:
            shared.append(candidate)
    return shared
18.
def union2(a, b):
    """Output a + Loop over b and check if present in a"""
    # Everything in a is kept as-is; b contributes only the values that a
    # does not already contain.
    extras = []
    for candidate in b:
        if candidate not in a:
            extras.append(candidate)
    return a + extras
22.
def inter3(a, b):
    """Output a after removing all elements not in b"""
    # Work on a copy so the caller's list is left untouched, then drop every
    # element of a that is absent from b.  The loop iterates over a (not the
    # copy), so shrinking the copy never disturbs the iteration.
    res = list(a)
    for ai in a:
        if ai not in b:
            res.remove(ai)
    return res
29.
def inter4(a, b):
    """Make heaps of a and b. Copy to output elements that are in a and b"""
    # Heapify private copies so the inputs are not modified.
    heap_a, heap_b = list(a), list(b)
    heapq.heapify(heap_a)
    heapq.heapify(heap_b)

    common = []
    # Merge-style walk: always discard the smaller of the two minima; when
    # the minima are equal the value is common to both inputs.
    while heap_a and heap_b:
        top_a, top_b = heap_a[0], heap_b[0]
        if top_a < top_b:
            heapq.heappop(heap_a)
        elif top_a > top_b:
            heapq.heappop(heap_b)
        else:
            common.append(heapq.heappop(heap_a))
            heapq.heappop(heap_b)
    return common
46.
def union4(a, b):
    """Make heaps of a and b. Copy to output elements that are in a or b"""
    # Heapify private copies so the inputs are not modified.
    heap_a, heap_b = list(a), list(b)
    heapq.heapify(heap_a)
    heapq.heapify(heap_b)

    merged = []
    # Merge-style walk: emit the smaller minimum; when both minima are equal
    # emit it once and drop it from both heaps.
    while heap_a and heap_b:
        top_a, top_b = heap_a[0], heap_b[0]
        if top_a < top_b:
            merged.append(heapq.heappop(heap_a))
        elif top_a > top_b:
            merged.append(heapq.heappop(heap_b))
        else:
            merged.append(heapq.heappop(heap_a))
            heapq.heappop(heap_b)
    # At most one heap still holds elements; that tail is appended as it is
    # stored in the heap list.
    merged = merged + heap_a + heap_b
    return merged
63.
def inter99(a, b):
    """Make set of a and b. Output list of intersection."""
    # Convert both inputs to sets and let the set type do the work.
    set_a = set(a)
    set_b = set(b)
    common = set_a & set_b
    return list(common)
67.
def union99(a, b):
    """Make set of a and b. Output list of union."""
    # Convert both inputs to sets and let the set type do the work.
    set_a = set(a)
    set_b = set(b)
    combined = set_a | set_b
    return list(combined)
71.
def union99a(a, b):
    """Output list of set of a + b."""
    # Concatenate first, then de-duplicate through a single set.
    joined = a + b
    distinct = set(joined)
    return list(distinct)
75.
76. if __name__ == '__main__':
77.     from time import clock
78.     from random import randint
79.     from collections import defaultdict
80.     import math
81.
82.     def ranarr(n, N):
83.         data = range(N)
84.         res = []
85.         for i in range(n):
86.             res.append(data.pop(randint(0, N - i - 1)))
87.         return res
88.
89.     timings = defaultdict(list)
90.     def tim(method, a, b, n=100):
91.         name = method.__name__
92.         t0 = clock()
93.         for i in xrange(n):
94.             res = method(a, b)
95.         ti = clock() - t0
96.         timings[name].append((len(a), ti))
97.
98.     Ns = [10, 20, 40, 80, 160, 320, 640, 1280, 2560, 5120, 10240]
99.     a_arr, b_arr = {}, {}
100.     for N in Ns:
101.         a_arr[N] = ranarr(N, 3 * N)
102.         b_arr[N] = ranarr(N, 3 * N)
103.
104.
105.     intersections = [inter1, inter2, inter3, inter4, inter99]
106.     unions = [union1, union2, union4, union99, union99a]
107.     for method in intersections + unions:
108.         for N in Ns:
109.             a = a_arr[N]
110.             b = b_arr[N]
111.             tim(method, a, b, 1)
112.
113.     def fit_timings(model, t):
114.         a = sum(ti * model(ni) for ni,ti in t) / sum(model(ni) ** 2 for ni,ti in t)
115.         err2 = sum((a * model(ni) - ti) ** 2 for ni, ti in t)
116.         return a, err2
117.
118.     def check_models(t):
119.         res = {}
120.         besterr = float("inf")
121.         mod_n = lambda n: n
122.         mod_n2 = lambda n: n ** 2
123.         mod_nlogn = lambda n: n * math.log(n)
124.         n2 = lambda n: n ** 2
125.         for model, mname in [(mod_n, "O(n)") , (mod_n2, "O(n2)"), (mod_nlogn, "O(nlogn)")]:
126.             coeff, err2 = fit_timings(model, t)
127.             res[mname] = err2
128.             if err2 < besterr:
129.                 bestmethod = mname
130.                 besterr = err2
131.         serr, altmethod = min((v, key) for key, v in res.iteritems() if key != bestmethod)
132.         ratio = serr / besterr
133.         return bestmethod, coeff, ratio, altmethod, res
134.
135.     for method in intersections + unions:
136.         name = method.__name__
137.         doc = method.__doc__
138.         times = timings[name]
139.         bestmethod, coeff, ratio, altmethod, fiterr = check_models(times)
140.         print "{0:8s} {1:8s} coeff: {2:6.2e}  ratio:{3:6.0f}   alt: {4:8s} "\
141.             "{5:s}".format(name, bestmethod, coeff, ratio, altmethod, doc)
142.
143. """
144. inter1   O(n2)    coeff: 8.78e-05  ratio:   704   alt: O(nlogn) Loop over a and b and check if there are doubles
145. inter2   O(n2)    coeff: 3.12e-05  ratio:  1958   alt: O(nlogn) Loop over a and check if present in b
146. inter3   O(n2)    coeff: 3.47e-05  ratio:   510   alt: O(nlogn) Output a after removing all elements in b
147. inter4   O(nlogn) coeff: 1.03e-05  ratio:    34   alt: O(n)     Make heaps of a and b. Copy to output elements that are in a and b
148. inter99  O(n)     coeff: 2.71e-07  ratio:     5   alt: O(nlogn) Make set of a and b. Output list of intersection.
149. union1   O(n2)    coeff: 5.83e-05  ratio:  7656   alt: O(nlogn) Loop over the elements of a+b
150. union2   O(n2)    coeff: 3.06e-05  ratio:  1225   alt: O(nlogn) Output a + Loop over b and check if present in a
151. union4   O(nlogn) coeff: 1.09e-05  ratio:     3   alt: O(n)     Make heaps of a and b. Copy to output elements that are in a or b
152. union99  O(n)     coeff: 2.94e-07  ratio:     4   alt: O(nlogn) Make set of a and b. Output list of union.
153. union99a O(nlogn) coeff: 2.30e-07  ratio:     1   alt: O(n)     Output list of set of a + b.
154. """
