Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- from graphviz import Digraph
- from math import floor
- from time import sleep
- import os
- ############
- # GRAPHVIZ #
- ############
def show(view=False, name="avl", directory="C:/Users/jan/Documents/Projects/avl/"):
    """Render the current global tree with graphviz.

    Every node is drawn with its key, suffixed with "x" when it is lazily
    deleted; parent->child edges are labelled "L"/"R". *name* becomes the
    graph's caption, while the file handed to ``Digraph.render`` is always
    ``avl.gv`` inside *directory* (default: the original hard-coded path).
    Does nothing for an empty tree.

    :param view: forwarded to ``Digraph.render``; when true, also pause 0.5 s.
    :param name: caption drawn on the graph.
    :param directory: folder that receives ``avl.gv``.
    """
    if not ROOT:
        return
    dot = Digraph(comment='TREE')
    # dot_write is only called on the root's children, so the root itself has
    # to be declared here; otherwise a single-node tree renders empty and a
    # deleted root is drawn without its "x" marker.
    root_id = str(ROOT.key)
    dot.node(root_id, root_id + "x" if ROOT.deleted else root_id)
    dot_write(dot, ROOT.left, label="L")
    dot_write(dot, ROOT.right, label="R")
    dot.attr(label=name)
    dot.render(os.path.join(directory, "avl.gv"), view=view)
    if view:
        sleep(0.5)
def dot_write(dot, node, label="x"):
    """Recursively add *node* and its descendants to the graphviz graph *dot*.

    Each visited node is declared with its key as id (label suffixed with "x"
    when lazily deleted) and linked from its parent by an edge labelled
    *label*; its children are then visited with edge labels "L" and "R".
    Every visited node must have a ``parent`` set.
    """
    if node is None:
        return
    node_id = str(node.key)
    dot.node(node_id, node_id + "x" if node.deleted else node_id)
    dot.edge(str(node.parent.key), node_id, label)
    for child, side in ((node.left, "L"), (node.right, "R")):
        dot_write(dot, child, label=side)
- ########
- # NODE #
- ########
def tree_height():
    """Return the cached height of the global tree, or -1 when it is empty."""
    if not ROOT:
        return -1
    return ROOT.height
class Node:
    """One node of the AVL tree, carrying a lazy-deletion flag.

    ``height`` follows the convention of set_height_balance: a leaf has
    height 0 and a missing child counts as -1. A freshly created node is a
    leaf, so it starts at 0 (a default of 1 made a root that was never passed
    through set_height_balance report a different height than any other leaf).
    """

    def __init__(self, key):
        self.key = key
        self.deleted = False        # lazily marked as removed
        self.height = 0             # leaf height, matching set_height_balance
        self.left = None
        self.right = None
        self.parent = None
        self.balance = 0            # height(left) - height(right)
        self.deleted_upto_here = 0  # max deleted nodes on a downward path, incl. self

    def __str__(self):
        l = self.left.key if self.left else "x"
        r = self.right.key if self.right else "x"
        par = self.parent.key if self.parent else "x"
        d = "x" if self.deleted else ""
        return "{:2d} h{}, p{}, d{}, b{:2d}, l:{}, r:{} {}".format(
            self.key, self.height, par, self.deleted_upto_here, self.balance, l, r, d)
- #######
- # AVL #
- #######
# Module-level tree state shared by every function in this file.
ROOT = None          # current tree root (a Node), or None for an empty tree
ROTATIONS = 0        # number of single rotations performed so far
CONSOLIDATIONS = 0   # number of consolidation passes performed so far
def avl_insert(key):
    """Insert *key* as a new node and rebalance the tree.

    Returns the new Node, or None when a node with *key* already exists
    (regardless of its deleted flag).
    """
    global ROOT
    if ROOT is None:
        ROOT = Node(key)
        # Normalize the lone node with the same height convention the rest of
        # the tree uses (leaf = 0, missing child = -1) instead of trusting the
        # Node default.
        set_height_balance(ROOT)
        return ROOT
    node = ROOT
    while True:
        if key == node.key:
            return None
        child = node.left if key < node.key else node.right
        if child is None:
            break
        node = child
    new_node = Node(key)
    new_node.parent = node
    if key < node.key:
        node.left = new_node
    else:
        node.right = new_node
    # Refresh heights, balances and deleted counts up to the root, rotating as needed.
    propagate(new_node)
    return new_node
def rotate_right(b):
    """Right-rotate the subtree rooted at *b*; return the new subtree root.

    b's left child ``a`` takes b's place, b becomes a's right child and adopts
    a's former right subtree. Increments ROTATIONS and updates ROOT when b
    was the root. Cached height/balance and deleted_upto_here are recomputed
    for b, a and a's parent.
    """
    global ROOT, ROTATIONS
    ROTATIONS += 1
    parent = b.parent
    a = b.left
    mid = a.right
    a.right = b
    b.parent = a
    b.left = mid
    if mid is not None:
        mid.parent = b
    a.parent = parent
    if parent is None:
        ROOT = a
    elif parent.left is b:
        parent.left = a
    else:
        parent.right = a
    # Both rotated nodes now have different descendants, so every cached
    # per-subtree value must be refreshed, b first since it is now a's child.
    # deleted_upto_here included: a stale value here is later folded into
    # all ancestors and over-counts the longest deleted path.
    set_height_balance(b)
    set_deleted_upto_here(b)
    set_height_balance(a)
    set_deleted_upto_here(a)
    set_height_balance(a.parent)
    set_deleted_upto_here(a.parent)
    return a
def rotate_left(a):
    """Left-rotate the subtree rooted at *a*; return the new subtree root.

    a's right child ``b`` takes a's place, a becomes b's left child and adopts
    b's former left subtree. Increments ROTATIONS and updates ROOT when a
    was the root. Cached height/balance and deleted_upto_here are recomputed
    for a, b and b's parent.
    """
    global ROOT, ROTATIONS
    ROTATIONS += 1
    parent = a.parent
    b = a.right
    mid = b.left
    b.left = a
    a.parent = b
    a.right = mid
    if mid is not None:
        mid.parent = a
    b.parent = parent
    if parent is None:
        ROOT = b
    elif parent.left is a:
        parent.left = b
    else:
        parent.right = b
    # Both rotated nodes now have different descendants, so every cached
    # per-subtree value must be refreshed, a first since it is now b's child.
    # deleted_upto_here included: a stale value here is later folded into
    # all ancestors and over-counts the longest deleted path.
    set_height_balance(a)
    set_deleted_upto_here(a)
    set_height_balance(b)
    set_deleted_upto_here(b)
    set_height_balance(b.parent)
    set_deleted_upto_here(b.parent)
    return b
def solve_balance(node):
    """Restore the AVL property at *node* with a single or double rotation.

    A balance of +2 or more is fixed by a right rotation, preceded by a left
    rotation of the left child when that child leans right (LR case).
    A balance of -2 or less is the mirror image (RL case). Returns *node*
    itself, which after rotating is no longer the subtree root.
    """
    skew = node.balance
    if skew >= 2:
        if node.left.balance < 0:
            rotate_left(node.left)    # LR: straighten the left child first
        rotate_right(node)
    elif skew <= -2:
        if node.right.balance > 0:
            rotate_right(node.right)  # RL: straighten the right child first
        rotate_left(node)
    return node
def find_key(key):
    """Return the node holding *key* in the global tree, or None if absent.

    Nodes marked as deleted are still found.
    """
    current = ROOT
    while current:
        if key == current.key:
            return current
        if key < current.key:
            current = current.left
        else:
            current = current.right
    return None
- def set_height_balance(node):
- if node is not None:
- left_height = -1 if node.left is None else node.left.height
- right_height = -1 if node.right is None else node.right.height
- node.height = 1 + max(left_height, right_height)
- node.balance = left_height - right_height
- assert node.balance is not None
def set_deleted_upto_here(node):
    """Recompute *node*'s deleted_upto_here from its children's cached values.

    The value is the largest number of deleted nodes on any downward path
    starting at *node*: the larger child value (0 for a missing child), plus
    one if *node* itself is deleted. A None *node* is ignored.
    """
    if node is None:
        return
    deepest_child = max(
        node.left.deleted_upto_here if node.left else 0,
        node.right.deleted_upto_here if node.right else 0,
    )
    node.deleted_upto_here = deepest_child + (1 if node.deleted else 0)
def propagate_delete(node):
    """Refresh deleted_upto_here on *node* and every ancestor up to the root.

    Call after toggling a node's deleted flag; heights are left untouched.
    """
    current = node
    while current:
        set_deleted_upto_here(current)
        current = current.parent
def propagate(node):
    """Walk from *node* up to the root, refreshing cached fields and rebalancing.

    At each step the node's height/balance and deleted_upto_here are
    recomputed. When the balance reaches +/-2 the node is fixed with
    solve_balance, and the walk continues from the node's parent as it
    stands after any rotation.
    """
    current = node
    while current is not None:
        set_height_balance(current)
        set_deleted_upto_here(current)
        if abs(current.balance) >= 2:
            solve_balance(current)
        current = current.parent
# Sentinel meaning "print the current global tree"; None already means "empty subtree".
_CURRENT_ROOT = object()


def print_tree(node=_CURRENT_ROOT):
    """Print the subtree rooted at *node* in key order, one node per line.

    Called without an argument it prints the whole global tree. The root is
    looked up at call time: a plain ``node=ROOT`` default is evaluated once,
    when the function is defined (ROOT is still None then), so it would never
    print anything.
    """
    if node is _CURRENT_ROOT:
        node = ROOT
    if node:
        print_tree(node.left)
        print(node)
        print_tree(node.right)
def find_max_key_in_subtree(node):
    """Return the node with the largest key under *node*: its rightmost descendant."""
    current = node
    while current.right:
        current = current.right
    return current
def find_min_key_in_subtree(node):
    """Return the node with the smallest key under *node*: its leftmost descendant."""
    current = node
    while current.left:
        current = current.left
    return current
def remove_leaf(node):
    """Unlink the leaf *node* from the tree and rebalance from its parent upward.

    Removing a parentless leaf empties the tree (ROOT becomes None).
    Asserts that *node* has no children.
    """
    global ROOT
    assert not node.left and not node.right, "{} is not leaf".format(node)
    parent = node.parent
    if not parent:
        ROOT = None
        return
    # BST order tells which side of the parent the leaf hangs on.
    if node.key < parent.key:
        parent.left = None
    else:
        parent.right = None
    propagate(parent)
def avl_remove(node):
    """Physically remove *node*'s key from the tree; None is ignored.

    A leaf is unlinked directly. Otherwise a donor is chosen: the in-order
    predecessor (max of the left subtree), or the successor (min of the right
    subtree) when there is no left child. The donor is removed recursively
    and its key and deleted flag are moved into *node*.
    """
    if not node:
        return
    if not node.left and not node.right:
        remove_leaf(node)
        return
    if node.left:
        donor = find_max_key_in_subtree(node.left)
    else:
        donor = find_min_key_in_subtree(node.right)
    assert donor
    # Capture before recursing: removing a non-leaf donor rewrites its fields.
    donor_key = donor.key
    donor_deleted = donor.deleted
    avl_remove(donor)
    node.key = donor_key
    # node now represents the donor's key, so it takes the donor's flag too.
    node.deleted = donor_deleted
    # The rebalancing walk above ran with node's old flag; refresh node and
    # its ancestors so they stop counting it if it is no longer deleted.
    propagate_delete(node)
def should_consolidate():
    """Return True when the tree has too many lazily deleted nodes on one path.

    The trigger is: the largest number of deleted nodes on any root-to-leaf
    path (ROOT.deleted_upto_here) is at least 1 + floor(height / 2).
    An empty tree never needs consolidation; it used to raise AttributeError
    on ``None.deleted_upto_here``, e.g. for a delete on an empty tree.
    """
    if ROOT is None:
        return False
    height = tree_height()
    longest_delete_path = ROOT.deleted_upto_here
    return longest_delete_path >= 1 + floor(height / 2.0)
def test_tree():
    """Debug scenario: build a small tree rooted at 0, render it, then lazily delete keys.

    Key 18 is deleted twice on purpose to exercise the already-deleted path.
    """
    global ROOT
    ROOT = Node(0)
    for key in (5, -5, 10, 16, 18, 20, 22):
        avl_insert(key)
    show(False)
    for key in (22, 18, 20, 18):
        delete(key)
    print("TREE HEIGHT {}".format(tree_height()))
    print("DONE")
- #######
- # ALG #
- #######
def delete(key):
    """Lazily delete *key* by marking its node instead of unlinking it.

    Absent or already-deleted keys are left alone. A consolidation check
    runs after every call, whether or not anything was marked.
    """
    node = find_key(key)
    if node is not None and not node.deleted:
        node.deleted = True
        # The node now counts toward the deleted-path length of every ancestor.
        propagate_delete(node)
    consolidate()
def insert(key):
    """Insert *key*, or un-delete it if a lazily deleted node already holds it.

    A consolidation check runs afterwards. Before this fix the check was
    unreachable, placed after both branches had already returned. Returns
    the node holding *key*; the return value exists only for debugging.
    """
    node = find_key(key)
    if node:
        node.deleted = False
        propagate_delete(node)
        result = node
    else:
        result = avl_insert(key)
    consolidate()
    return result
def find_first_node(node):
    """Return the first deleted node in post-order under *node*, or False if none.

    Because children are searched before their parent, the returned node
    never has a deleted descendant.
    """
    if not node:
        return False
    for child in (node.left, node.right):
        found = find_first_node(child)
        if found:
            return found
    return node if node.deleted else False
def consolidate():
    """Physically remove every lazily deleted node, if should_consolidate() says so.

    Each pass increments CONSOLIDATIONS once. Deleted nodes are removed one
    at a time, each time re-searching the current tree for the first deleted
    node in post-order.
    """
    global CONSOLIDATIONS
    if not should_consolidate():
        return
    CONSOLIDATIONS += 1
    while True:
        victim = find_first_node(ROOT)
        if not victim:
            break
        avl_remove(victim)
def delete_order(node):
    """Return the keys of deleted nodes under *node*, listed in post-order."""
    if node is None:
        return []
    keys = delete_order(node.left)
    keys.extend(delete_order(node.right))
    if node.deleted:
        keys.append(node.key)
    return keys
def parse(i, directory="C:/Users/jan/Documents/Projects/oblak/"):
    """Run dataset *i* and compare the outcome with the expected results.

    Reads ``pubNN.in`` and ``pubNN.out`` from *directory*, where NN is *i*
    zero-padded to two digits. The first line of the input is skipped
    (presumably an operation count; confirm). Each remaining non-blank line
    is one integer operation. The output file holds whitespace-separated
    integers. The parsed operations, the expected numbers and a verdict are
    printed.

    Returns True when the run's (height, rotations, consolidations) equal the
    first three expected numbers.
    """
    stem = "pub" + str(i).zfill(2)
    with open(os.path.join(directory, stem + ".in"), "r") as f:
        lines = f.readlines()
    # Blank lines (e.g. a trailing empty line) would make int() fail.
    ops = [int(line) for line in lines[1:] if line.strip()]
    print(ops)
    result = tuple(run(ops))
    with open(os.path.join(directory, stem + ".out"), "r") as f:
        # split() with no argument tolerates repeated spaces and trailing newlines.
        expected = [int(tok) for tok in f.read().split()]
    print("Correct: {}".format(expected))
    ok = tuple(expected[:3]) == result
    if ok:
        print("== CORRECT ==")
    else:
        print("== WRONG ==")
    return ok
def run(nums, *, visualize=True):
    """Replay a sequence of operations on a fresh tree and return the statistics.

    A positive number n inserts n; any other number n deletes -n. Counters
    and the tree are reset first. When *visualize* is true (the default) the
    tree is rendered with show() after every operation, labelled with its
    index. Pass False to skip graphviz entirely.

    Returns (tree height, ROTATIONS, CONSOLIDATIONS), also printed as a list.
    """
    global ROOT, ROTATIONS, CONSOLIDATIONS
    ROTATIONS = CONSOLIDATIONS = 0
    ROOT = None
    for i, num in enumerate(nums):
        if num > 0:
            insert(num)
        else:
            delete(-num)
        if visualize:
            show(False, str(i))
    result = (tree_height(), ROTATIONS, CONSOLIDATIONS)
    print("\nresult: {}".format(list(result)))
    return result
def test_easy():
    """Run and check the public datasets 1 through 9, printing a header for each."""
    for dataset in range(1, 10):
        print("DATASET {}".format(dataset))
        parse(dataset)
# Run the bundled datasets only when executed as a script, so importing this
# module neither runs them nor shuts down the importing interpreter.
if __name__ == "__main__":
    test_easy()
    print("DONE.")
Advertisement
Add Comment
Please, Sign In to add comment