goodwish

399. Nuts & Bolts Problem

Nov 28th, 2019
89
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 5.15 KB | None | 0 0
  1. #meat
  2. quick sort  #quicksort 思想.
  3. 任选一个数, 作为 pivot, 用 compare.cmp(a, b) 来 partition nuts and bolts 数组.
  4. 记录 pivot 位置 mid, [start, mid - 1], [mid + 1, end] 分成上半区和下半区, 递归, 排序.
  5.     self.qsort(A, start, mid - 1)
  6.     self.qsort(A, mid + 1, end)
  7.  
  8. 递归出口: start >= end
  9.  
  10. 有三种方法 partition.
  11.  
  12. way 1: 用两个辅助数组帮助完成分区 partition. 类似 merge sort 里面的 merge.
  13.  
  14. 封装 partition(A, start, end):
  15. pivot = A[start] # 通常情况,
  16. pivot = A[i] if compare.cmp(A[i], B[start]) == 0; # 本题特别要求的比较查找.
  17.  
  18. low, high = [], []
  19. 比 pivot 小的数放进 low 数组, # if compare.cmp(A[i], B[start]) == -1:
  20. 比 pivot 大的数放进 high 数组, # if .. cmp() == 1:
  21. A[start: end + 1] = low + [pivot] + high # 完成分区子数组,放回原区间.
  22. mid = start + len(low)
  23. return mid, 即分区中轴线.
  24.  
  25. note, mid = start + len(low_arr), 记得加上基准位置 start.
  26.  
  27. way 2: 类似 0,1,2 三色排序.
  28.  
  29. def partition(A, start, end):
  30. 区间第一个数 A[start] 作为 pivot.
  31. 左右指针 l, r = start + 1, end;
  32. while l <= r: # 小于等于, 有等于, l 和 r 重合, 确保处理最后一个元素 r.
  33. - 小于 pivot 的数在左指针 l 前边,
  34. - 大于 pivot 的数在右指针 r 后边,
  35. 循环结束后, 交换 start, r 的数值, r 就是完成分区的中轴线, 左边小, 右边大.
  36. 返回 r , 即分区中轴线.
  37.  
  38. 优点, 知道 pivot 的位置. mid - 1, mid + 1 两边递归.
  39. 缺点, 分区有等于条件, 极端情况每次走到尾, O(n^2) time.
  40.  - i.e. [1,1,1,1,1,1,1,1,1], 每次分区只减少一个数, 不能两边均分, 大于 O(n log n) time.
  41.  
  42. def qsort(self, A, start, end):
  43.     if start >= end:
  44.         return
  45.     mid = self.partition(A, start, end)
  46.     self.qsort(A, start, mid - 1)
  47.     self.qsort(A, mid + 1, end)
  48.  
  49. def partition(self, A, start, end):
  50.     pivot = A[start]
  51.     l, r = start + 1, end
  52.     while l <= r:
  53.         if A[l] <= pivot: # l 前边小于等于 pivot
  54.             l += 1
  55.         else: # A[l] > pivot: r 后边大于 pivot
  56.             A[l], A[r] = A[r], A[l]
  57.             r -= 1
  58.     A[start], A[r] = A[r], A[start]
  59.     return r
  60.  
  61. way 3: 中间数作为 pivot, 左右向中间逼近. - 九章算法班.
  62. mid = (start + end)//2
  63. pivot = A[mid]
  64. l, r = start, end
  65. while l <= r:
  66.   while l <= r and A[l] < pivot:
  67.     l += 1
  68.   while l <= r and A[r] > pivot:
  69.     r -= 1
  70.   if l <= r:
  71.     A[l], A[r] = A[r], A[l]
  72.     l += 1
  73.     r -= 1
  74.  
  75. 最后 r 移到了 l 的前边/左边. A[start: r + 1] 小于等于中轴, A[l: end + 1] 大于等于中轴, 完成分区.
  76. 缺点: 不知道 pivot 的位置, 不知道 r, l 中间有没有 mid.
  77. 有一个迂回办法 workaround 就是用第1个数作为中轴 pivot,然后 partition 后边的,完成分区后右指针和第1个数交换,右指针的位置就是分区后的中轴线。
  78. .
  79.  
  80.  
  81.  
  82. """
  83. class Comparator:
  84.    def cmp(self, a, b)
  85. You can use Compare.cmp(a, b) to compare nuts "a" and bolts "b",
  86. if "a" is bigger than "b", it will return 1, else if they are equal,
  87. it will return 0, else if "a" is smaller than "b", it will return -1.
  88. When "a" is not a nut or "b" is not a bolt, it will return 2, which is not valid.
  89. """
  90.  
  91.  
  92. class Solution:
  93.     # @param nuts: a list of integers
  94.     # @param bolts: a list of integers
  95.     # @param compare: a instance of Comparator
  96.     # @return: nothing
  97.     def sortNutsAndBolts(self, nuts, bolts, compare):
  98.         self.q_sort(nuts, bolts, compare, 0, len(nuts) - 1)
  99.         print(nuts, bolts)
  100.    
  101.     def q_sort(self, A, B, compare, start, end):
  102.         # print(f"start {start}, end {end}")
  103.         if start >= end:
  104.             return
  105.         mid = (start + end)//2
  106.         new_mid_i = self.partition(A, B, compare, start, end)
  107.         # print(f"start {start}, mid {mid}, end {end}", "new_mid_i:", new_mid_i)
  108.  
  109.         self.q_sort(A, B, compare, start, new_mid_i - 1)
  110.         self.q_sort(A, B, compare, new_mid_i + 1, end)
  111.    
  112.     def partition(self, A, B, compare, start, end):
  113.         lo, hi, mi = [], [], []
  114.         mid = (start + end)//2
  115.         # print(A[start: end + 1])
  116.         # print(B[start: end + 1])
  117.         b_mid = B[mid]
  118.         for v in A[start: end + 1]:
  119.             if compare.cmp(v, b_mid) == -1:
  120.                 lo.append(v)
  121.             elif compare.cmp(v, b_mid) == 0:
  122.                 mi.append(v)
  123.                 a_mid = v
  124.             else:
  125.                 hi.append(v)
  126.         # print("a,b mid:", a_mid, b_mid)
  127.         # print("lo,mi,hi", lo, mi, hi)
  128.         new_mid_i = start + len(lo)
  129.         i = start
  130.         for v in lo + mi + hi:
  131.             A[i] = v
  132.             i += 1
  133.  
  134.         lo, hi, mi = [], [], []
  135.         for v in B[start: end + 1]:
  136.             if compare.cmp(a_mid, v) == 1:
  137.                 lo.append(v)
  138.             elif compare.cmp(a_mid, v) == 0:
  139.                 mi.append(v)
  140.             else:
  141.                 hi.append(v)
  142.         # print("lo,mi,hi", lo, mi, hi)
  143.         i = start
  144.         for v in lo + mi + hi:
  145.             B[i] = v
  146.             i += 1
  147.         # print("after partition:")
  148.         # print(A[start: end + 1])
  149.         # print(B[start: end + 1])
  150.         return new_mid_i
Add Comment
Please, Sign In to add comment