Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
class Node:
    """A single AVL tree node.

    ``balance`` is height(right subtree) - height(left subtree); the tree
    keeps it within {-1, 0, 1}.
    """

    def __init__(self, key, value, balance=0, left=None, right=None):
        self.key = key
        self.value = value
        self.balance = balance
        self.left = left
        self.right = right


class AVL:
    """Self-balancing binary search tree (AVL) mapping keys to values.

    Keys must be mutually comparable with ``<`` and ``>``.  The recursive
    helpers return a *growth* value: the change in height of the subtree they
    were called on.  Each ancestor uses it to update its ``balance`` without
    storing explicit heights.
    """

    def __init__(self):
        self.root = None
        self.size = 0  # number of distinct keys currently stored

    def put(self, key, value):
        """Insert ``key`` -> ``value``, replacing the value if ``key`` exists."""
        self.root, _ = self._put(self.root, key, value)

    def remove(self, key):
        """Delete ``key`` from the tree.

        Returns:
            True if the key was present and removed, False otherwise.
        """
        self.root, _, found = self._remove(self.root, key)
        if found:
            self.size -= 1
        return found

    # -- balance bookkeeping -------------------------------------------

    @staticmethod
    def _left_changed(pointer, delta):
        """Record that ``pointer.left`` changed height by ``delta``.

        Updates ``pointer.balance`` and returns the resulting height change
        of the subtree rooted at ``pointer``.
        """
        # height(pointer) = 1 + height(right) + max(0, -balance), and the
        # right subtree is unchanged, so only that last term can move.
        before = max(0, -pointer.balance)
        pointer.balance -= delta
        return max(0, -pointer.balance) - before

    @staticmethod
    def _right_changed(pointer, delta):
        """Record that ``pointer.right`` changed height by ``delta``.

        Updates ``pointer.balance`` and returns the resulting height change
        of the subtree rooted at ``pointer``.
        """
        # height(pointer) = 1 + height(left) + max(0, balance); left is unchanged.
        before = max(0, pointer.balance)
        pointer.balance += delta
        return max(0, pointer.balance) - before

    # -- insertion / deletion ------------------------------------------

    def _put(self, pointer, key, value):
        """Insert into the subtree rooted at ``pointer``.

        Returns:
            ``(new_root, growth)``, where ``growth`` is the subtree's height change.
        """
        if pointer is None:
            self.size += 1
            return Node(key, value), 1
        if key < pointer.key:
            pointer.left, child_growth = self._put(pointer.left, key, value)
            growth = self._left_changed(pointer, child_growth)
        elif key > pointer.key:
            pointer.right, child_growth = self._put(pointer.right, key, value)
            growth = self._right_changed(pointer, child_growth)
        else:
            # Existing key: replace the entry; the shape and heights are unchanged.
            pointer.key, pointer.value = key, value
            return pointer, 0
        rotated, rotation_growth = self._rotate(pointer)
        return rotated, growth + rotation_growth

    def _remove(self, pointer, key):
        """Remove ``key`` from the subtree rooted at ``pointer``.

        Returns:
            ``(new_root, growth, found)``, where ``growth`` is the subtree's
            height change and ``found`` tells whether ``key`` was present.
        """
        if pointer is None:
            return None, 0, False
        if key < pointer.key:
            pointer.left, child_growth, found = self._remove(pointer.left, key)
            growth = self._left_changed(pointer, child_growth)
        elif key > pointer.key:
            pointer.right, child_growth, found = self._remove(pointer.right, key)
            growth = self._right_changed(pointer, child_growth)
        else:
            found = True
            # Zero or one child: splice the node out; the subtree loses one level.
            if pointer.left is None:
                return pointer.right, -1, True
            if pointer.right is None:
                return pointer.left, -1, True
            # Two children: move the in-order successor's entry into this node
            # and delete the successor from the right subtree instead.
            pointer.right, child_growth, successor = self._pop_min(pointer.right)
            pointer.key, pointer.value = successor.key, successor.value
            growth = self._right_changed(pointer, child_growth)
        rotated, rotation_growth = self._rotate(pointer)
        return rotated, growth + rotation_growth, found

    def _pop_min(self, pointer):
        """Detach the minimum node from the non-empty subtree at ``pointer``.

        Returns:
            ``(new_root, growth, min_node)``.
        """
        if pointer.left is None:
            return pointer.right, -1, pointer
        pointer.left, child_growth, minimum = self._pop_min(pointer.left)
        growth = self._left_changed(pointer, child_growth)
        rotated, rotation_growth = self._rotate(pointer)
        return rotated, growth + rotation_growth, minimum

    # -- rotations -----------------------------------------------------

    def _rotate(self, pointer):
        """Rebalance ``pointer`` if its balance has reached +-2.

        Returns:
            ``(new_root, growth)``, where ``growth`` is the height change caused
            by the rotation(s) (0 when no rotation is needed).
        """
        growth = 0
        if pointer.balance <= -2:
            if pointer.left.balance > 0:
                # Left-right case: rotate the child first and reattach it.
                pointer.left, inner = self._left(pointer.left)
                growth += self._left_changed(pointer, inner)
            pointer, outer = self._right(pointer)
            growth += outer
        elif pointer.balance >= 2:
            if pointer.right.balance < 0:
                # Right-left case: rotate the child first and reattach it.
                pointer.right, inner = self._right(pointer.right)
                growth += self._right_changed(pointer, inner)
            pointer, outer = self._left(pointer)
            growth += outer
        return pointer, growth

    def _left(self, pointer):
        """Rotate the subtree at ``pointer`` left.

        Returns:
            ``(new_root, growth)``, where ``growth`` is the height change.
        """
        child = pointer.right
        pointer.right = child.left
        child.left = pointer
        old_balance = pointer.balance
        pointer.balance = old_balance - 1 - max(child.balance, 0)
        child.balance = child.balance - 1 + min(pointer.balance, 0)
        # new height - old height, expressed purely in balances
        growth = 1 + max(pointer.balance, 0) + max(child.balance, 0) - max(old_balance, 0)
        return child, growth

    def _right(self, pointer):
        """Rotate the subtree at ``pointer`` right.

        Returns:
            ``(new_root, growth)``, where ``growth`` is the height change.
        """
        child = pointer.left
        pointer.left = child.right
        child.right = pointer
        old_balance = pointer.balance
        pointer.balance = old_balance + 1 - min(child.balance, 0)
        child.balance = child.balance + 1 + max(pointer.balance, 0)
        # mirror image of the formula in _left
        growth = 1 - min(pointer.balance, 0) - min(child.balance, 0) + min(old_balance, 0)
        return child, growth
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement