Emania

AVL

Dec 28th, 2018
131
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 9.63 KB | None | 0 0
  1. from graphviz import Digraph
  2. from math import floor
  3. from time import sleep
  4. import os
  5.  
  6.  
  7.  
  8.  
  9. ############
  10. # GRAPHVIZ #
  11. ############
  12.  
  13. def show(view=False, name="avl"):
  14.     if ROOT:
  15.         dot = Digraph(comment='TREE')
  16.         dot_write(dot, ROOT.left, label="L")
  17.         dot_write(dot, ROOT.right, label="R")
  18.         dot.attr(label=name)
  19.         dot.render('C:/Users/jan/Documents/Projects/avl/{}.gv'.format("avl"), view=view)
  20.         if view:
  21.             sleep(0.5)
  22.  
  23. def dot_write(dot, node, label="x"):
  24.     if node is not None:
  25.         node_label = str(node.key) if not node.deleted else str(node.key) + "x"
  26.         dot.node(str(node.key), node_label)
  27.         dot.edge(str(node.parent.key), str(node.key), label)
  28.         dot_write(dot, node.left, label="L")
  29.         dot_write(dot, node.right, label="R")
  30.  
  31.  
  32. ########
  33. # NODE #
  34. ########
  35.  
  36. def tree_height():
  37.     return ROOT.height if ROOT else -1
  38.  
  39. class Node:
  40.     key = None
  41.     deleted = False
  42.     height = 1
  43.     left = None
  44.     right = None
  45.     parent = None
  46.     balance = 0
  47.     deleted_upto_here = 0
  48.  
  49.     def __init__(self, key):
  50.         self.key = key
  51.  
  52.     def __str__(self):
  53.         l = self.left.key if self.left else "x"
  54.         r = self.right.key if self.right else "x"
  55.         par = self.parent.key if self.parent else "x"
  56.         d = "x" if self.deleted else ""
  57.         return "{:2d} h{}, p{}, d{}, b{:2d}, l:{}, r:{} {}".format(
  58.             self.key, self.height, par, self.deleted_upto_here, self.balance, l, r, d)
  59.  
  60. #######
  61. # AVL #
  62. #######
  63.  
  64. ROOT = None
  65. CONSOLIDATIONS = 0
  66. ROTATIONS = 0
  67.  
  68. def avl_insert(key):
  69.     global ROOT
  70.     if not ROOT:
  71.         ROOT = Node(key)
  72.         return ROOT
  73.  
  74.     node = ROOT
  75.     while True:
  76.         if key == node.key:
  77.             return
  78.         if key < node.key and node.left is not None:
  79.             node = node.left
  80.             continue
  81.         if key > node.key and node.right is not None:
  82.             node = node.right
  83.             continue
  84.         break
  85.  
  86.     new_node = Node(key)
  87.     new_node.parent = node
  88.     if key > node.key:
  89.         node.right = new_node
  90.     if key < node.key:
  91.         node.left = new_node
  92.     propagate(new_node)
  93.     return new_node
  94.  
  95. def rotate_right(b):
  96.     global ROOT, ROTATIONS
  97.     ROTATIONS += 1
  98.     parent = b.parent
  99.     a = b.left
  100.     mid = a.right
  101.  
  102.     a.right = b
  103.     b.parent = a
  104.  
  105.     if mid:
  106.         b.left = mid
  107.         mid.parent = b
  108.     else:
  109.         b.left = None
  110.  
  111.     if parent:
  112.         a.parent = parent
  113.         if a.key < parent.key:
  114.             parent.left = a
  115.         if a.key > parent.key:
  116.             parent.right = a
  117.     else:
  118.         a.parent = None
  119.         ROOT = a
  120.     # propagate(b)
  121.     set_height_balance(b)
  122.     set_height_balance(a)
  123.     set_height_balance(a.parent)
  124.     return a
  125.  
  126.  
  127. def rotate_left(a):
  128.     global ROOT, ROTATIONS
  129.     ROTATIONS += 1
  130.     parent = a.parent
  131.     b = a.right
  132.     mid = b.left
  133.  
  134.     b.left = a
  135.     a.parent = b
  136.  
  137.     if mid:
  138.         a.right = mid
  139.         mid.parent = a
  140.     else:
  141.         a.right = None
  142.  
  143.     if parent:
  144.         b.parent = parent
  145.         if b.key < parent.key:
  146.             parent.left = b
  147.         if b.key > parent.key:
  148.             parent.right = b
  149.     else:
  150.         b.parent = None
  151.         ROOT = b
  152.  
  153.     # propagate(a)
  154.     set_height_balance(a)
  155.     set_height_balance(b)
  156.     # set_height_balance(b.parent)
  157.     return b
  158.  
  159. def solve_balance(node):
  160.     balance = node.balance
  161.     if balance >= 2:
  162.         if node.left.balance >= 0:
  163.             # ++ -> R
  164.             rotate_right(node)
  165.         else:
  166.             # +- -> LR
  167.             rotate_left(node.left)
  168.             rotate_right(node)
  169.     elif balance <= -2:
  170.         if node.right.balance <= 0:
  171.             # -- -> L
  172.             rotate_left(node)
  173.         else:
  174.             # -+ -> RL
  175.             rotate_right(node.right)
  176.             rotate_left(node)
  177.     return node
  178.  
  179. def find_key(key):
  180.     if not ROOT:
  181.         return None
  182.     node = ROOT
  183.     while True:
  184.         if not node:
  185.             return None
  186.         if key == node.key:
  187.             return node
  188.         elif key < node.key:
  189.             node = node.left
  190.         elif key > node.key:
  191.             node = node.right
  192.  
  193.  
  194.  
  195. def set_height_balance(node):
  196.     if node is not None:
  197.         left_height = -1 if node.left is None else node.left.height
  198.         right_height = -1 if node.right is None else node.right.height
  199.         node.height = 1 + max(left_height, right_height)
  200.         node.balance = left_height - right_height
  201.         assert node.balance is not None
  202.  
  203. def set_deleted_upto_here(node):
  204.     if node is not None:
  205.         left = node.left.deleted_upto_here if node.left else 0
  206.         right = node.right.deleted_upto_here if node.right else 0
  207.         node.deleted_upto_here = max(left, right) + 1 if node.deleted else max(left, right)
  208.  
  209. def propagate_delete(node):
  210.     while(node):
  211.         set_deleted_upto_here(node)
  212.         node = node.parent
  213.  
  214. def propagate(node):
  215.     if node is not None:
  216.         set_height_balance(node)
  217.         set_deleted_upto_here(node)
  218.         if node.balance <= -2 or 2 <= node.balance:
  219.             solve_balance(node)
  220.         propagate(node.parent)
  221.  
  222. def print_tree(node=ROOT):
  223.     if node:
  224.         print_tree(node.left)
  225.         print(node)
  226.         print_tree(node.right)
  227.  
  228. def find_max_key_in_subtree(node):
  229.     if not node.right:
  230.         return node
  231.     else:
  232.         return find_max_key_in_subtree(node.right)
  233.  
  234. def find_min_key_in_subtree(node):
  235.     if not node.left:
  236.         return node
  237.     else:
  238.         return find_min_key_in_subtree(node.left)
  239.  
  240. def remove_leaf(node):
  241.     global ROOT
  242.     assert not node.left and not node.right, "{} is not leaf".format(node)
  243.     if not node.parent:
  244.         ROOT = None
  245.         return
  246.     if node.parent.key > node.key:  # left
  247.         node.parent.left = None
  248.     else:  # right
  249.         node.parent.right = None
  250.     propagate(node.parent)
  251.     # del node
  252.  
  253.  
  254. def avl_remove(node):
  255.     if node:
  256.         if not node.left and not node.right: # is leaf
  257.             remove_leaf(node)
  258.         elif node.left:
  259.             tmp = find_max_key_in_subtree(node.left)
  260.             assert tmp
  261.             new_key = tmp.key
  262.             avl_remove(tmp)
  263.             node.key = new_key
  264.             node.deleted = False
  265.         elif node.right:
  266.             tmp = find_min_key_in_subtree(node.right)
  267.             assert tmp
  268.             new_key = tmp.key
  269.             avl_remove(tmp)
  270.             node.key = new_key
  271.             node.deleted = False
  272.  
  273. def should_consolidate():
  274.     height = tree_height()
  275.     longest_delete_path = ROOT.deleted_upto_here
  276.     return longest_delete_path >= 1 + floor(height / 2.0)
  277.  
  278.  
  279. def test_tree():
  280.     global ROOT
  281.     ROOT = Node(0)
  282.     n0 = ROOT
  283.     n5 = avl_insert(5)
  284.     nm5 = avl_insert(-5)
  285.     n10 = avl_insert(10)
  286.     n16 = avl_insert(16)
  287.     n18 = avl_insert(18)
  288.     n20 = avl_insert(20)
  289.     n22 = avl_insert(22)
  290.  
  291.     show(False)
  292.  
  293.     # find_key(18)
  294.     delete(22)
  295.     delete(18)
  296.     delete(20)
  297.     delete(18)
  298.  
  299.  
  300.  
  301.     # print(n15.parent)
  302.  
  303.  
  304.     print("TREE HEIGHT {}".format(tree_height()))
  305.     print("DONE")
  306.  
  307.  
  308. #######
  309. # ALG #
  310. #######
  311.  
  312. def delete(key):
  313.     node = find_key(key)
  314.     if node is ROOT:
  315.         pass
  316.     if node:
  317.         if not node.deleted:
  318.             node.deleted = True
  319.             propagate_delete(node)
  320.     consolidate()
  321.  
  322. def insert(key):
  323.     # this function can be void. Returns only for debugging
  324.     node = find_key(key)
  325.     if node:
  326.         node.deleted = False
  327.         propagate_delete(node)
  328.         return node
  329.     else:
  330.         return avl_insert(key)
  331.     consolidate()
  332.  
  333. def find_first_node(node):
  334.     if not node:
  335.         return False
  336.     l = find_first_node(node.left)
  337.     if l:
  338.         return l
  339.     r = find_first_node(node.right)
  340.     if r:
  341.         return r
  342.     return node if node.deleted else False
  343.  
  344.  
  345. def consolidate():
  346.     if should_consolidate():
  347.         global CONSOLIDATIONS, ROOT
  348.         CONSOLIDATIONS += 1
  349.         to_delete = find_first_node(ROOT)
  350.         while to_delete:
  351.             avl_remove(to_delete)
  352.             to_delete = find_first_node(ROOT)
  353.  
  354.  
  355. def delete_order(node):
  356.     if node is None:
  357.         return []
  358.     else:
  359.         if node.deleted:
  360.             return delete_order(node.left) + delete_order(node.right) + [node.key]
  361.         else:
  362.             return delete_order(node.left) + delete_order(node.right)
  363.  
  364. def parse(i):
  365.     DIR = "C:/Users/jan/Documents/Projects/oblak/"
  366.     name_in = "{}pub{}.in".format(DIR, str(i).zfill(2))
  367.     name_out = "{}pub{}.out".format(DIR, str(i).zfill(2))
  368.     with open(name_in, "r") as f:
  369.         data = f.readlines()
  370.         data = [int(e) for e in data[1:]]
  371.         print(data)
  372.     r0, r1, r2 = run(data)
  373.     with open(name_out, "r") as f:
  374.         data = f.read()
  375.         data = [int(e) for e in data.split(" ")]
  376.         print("Correct: {}".format(data))
  377.         if data[0] == r0 and data[1] == r1 and data[2] == r2:
  378.             print("== CORRECT ==")
  379.         else:
  380.             print("== WRONG ==")
  381.  
  382.  
  383. def run(nums):
  384.     global ROOT, ROTATIONS, CONSOLIDATIONS
  385.     ROTATIONS = CONSOLIDATIONS = 0
  386.     ROOT = None
  387.     for i, num in enumerate(nums):
  388.         # print("processing {:3d} / {:3d}: {}".format(i, len(nums), num))
  389.         if num > 0:
  390.             insert(num)
  391.         else:
  392.             delete(-num)
  393.         show(False, str(i))
  394.     print("\nresult:  {}".format([tree_height(), ROTATIONS, CONSOLIDATIONS]))
  395.     return tree_height(), ROTATIONS, CONSOLIDATIONS
  396.  
  397. def test_easy():
  398.     for i in range(1, 10):
  399.         print("DATASET {}".format(i))
  400.         parse(i)
  401.  
  402. # parse(2)
  403. test_easy()
  404. print("DONE.")
  405. exit()
Advertisement
Add Comment
Please, Sign In to add comment