Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
/// A single node of the AVL tree.
///
/// Member order is part of the aggregate layout and is kept as-is.
struct Node {
    int height{1};         ///< Height of the subtree rooted here; a fresh leaf has height 1.
    Node *left{nullptr};   ///< Left child (smaller values), or nullptr.
    Node *right{nullptr};  ///< Right child (larger values), or nullptr.
    int value{};           ///< Key stored in this node.
};
/// Self-balancing (AVL) binary search tree holding unique int keys.
///
/// The tree owns its nodes: they are allocated in insert() and released in
/// erase() or the destructor. Copying is disabled because a member-wise copy
/// would make two trees delete the same nodes; moving transfers ownership.
class AVLTree {
    Node *root_{nullptr};  // Root of the tree, nullptr when empty.
    int count_{0};         // Number of nodes currently stored.

public:
    AVLTree() = default;

    AVLTree(const AVLTree &) = delete;
    AVLTree &operator=(const AVLTree &) = delete;

    /// Takes over the nodes of `other`, leaving it empty.
    AVLTree(AVLTree &&other) noexcept : root_(other.root_), count_(other.count_) {
        other.root_ = nullptr;
        other.count_ = 0;
    }

    /// Releases the current nodes and takes over those of `other`.
    AVLTree &operator=(AVLTree &&other) noexcept {
        if (this != &other) {
            clear(root_);
            root_ = other.root_;
            count_ = other.count_;
            other.root_ = nullptr;
            other.count_ = 0;
        }
        return *this;
    }

    ~AVLTree() {
        clear(root_);
    }

    /// Returns the height of the tree (0 when empty).
    int getHeight() const {
        return height(root_);
    }

    /// Inserts `value`; a value that is already present is ignored.
    void insert(int value) {
        root_ = insert(root_, value);
    }

    /// Removes `value` if present; does nothing otherwise.
    void erase(int value) {
        root_ = erase(root_, value);
    }

    /// Returns a pointer to the stored value equal to `value`, or nullptr.
    /// Writing through the pointer would break the search-tree ordering.
    int *find(int value) {
        Node *node = find(root_, value);
        return node == nullptr ? nullptr : &node->value;
    }

    /// Returns a newly allocated array of getSize() elements holding the
    /// values in ascending order. The caller must release it with delete[].
    int *traversal() const {
        int *array = new int[count_];
        int index = 0;
        traversal(array, &index, root_);
        return array;
    }

    /// Returns a pointer to the smallest stored value that is >= `value`,
    /// or nullptr when every stored value is smaller.
    int *lowerBound(int value) {
        Node *current = root_;
        int *res = nullptr;
        while (current != nullptr) {
            if (current->value == value) {
                return &current->value;
            }
            if (current->value < value) {
                current = current->right;
            } else {
                // Candidate answer; a closer one can only be in the left subtree.
                res = &current->value;
                current = current->left;
            }
        }
        return res;
    }

    /// True when the tree holds no values.
    bool empty() const {
        return count_ == 0;
    }

    /// Returns the root node (nullptr when empty). The tree keeps ownership.
    Node *getRoot() {
        return root_;
    }

    /// Returns the number of stored values.
    int getSize() const {
        return count_;
    }

private:
    /// Height of a possibly-null subtree.
    static int height(const Node *node) {
        return node == nullptr ? 0 : node->height;
    }

    /// Balance factor: right height minus left height. `node` must be non-null.
    static int bfactor(const Node *node) {
        return height(node->right) - height(node->left);
    }

    /// Recomputes the cached height of `node` from its children.
    static void updateHeight(Node *node) {
        node->height = std::max(height(node->left), height(node->right)) + 1;
    }

    /// Inserts `value` into the subtree at `root` and returns the new subtree root.
    /// count_ is bumped only when a node is actually created.
    Node *insert(Node *root, int value) {
        if (root == nullptr) {
            Node *new_node = new Node;
            new_node->value = value;
            ++count_;
            return new_node;
        }
        if (value < root->value) {
            root->left = insert(root->left, value);
        } else if (value > root->value) {
            root->right = insert(root->right, value);
        } else {
            return root;  // Duplicate: the subtree is unchanged.
        }
        return balance(root);
    }

    /// Unlinks (without deleting) the minimum node of a non-null subtree and
    /// returns the rebalanced remainder.
    static Node *eraseMin(Node *root) {
        if (root->left == nullptr) {
            return root->right;
        }
        root->left = eraseMin(root->left);
        return balance(root);
    }

    /// Removes `value` from the subtree at `root` and returns the new subtree root.
    /// count_ is decremented only when a node is actually deleted.
    Node *erase(Node *root, int value) {
        if (root == nullptr) {
            return nullptr;
        }
        if (value < root->value) {
            root->left = erase(root->left, value);
        } else if (value > root->value) {
            root->right = erase(root->right, value);
        } else {
            Node *left = root->left;
            Node *right = root->right;
            delete root;
            --count_;
            if (right == nullptr) {
                return left;  // Also covers the leaf case (left == nullptr).
            }
            // Replace the removed node with the minimum of its right subtree.
            Node *min_node = min(right);
            min_node->right = eraseMin(right);
            min_node->left = left;
            return balance(min_node);
        }
        return balance(root);
    }

    /// Returns the node holding `value` in the subtree at `root`, or nullptr.
    static Node *find(Node *root, int value) {
        while (root != nullptr && root->value != value) {
            root = value < root->value ? root->left : root->right;
        }
        return root;
    }

    /// Returns the leftmost node of a non-null subtree.
    static Node *min(Node *root) {
        while (root->left != nullptr) {
            root = root->left;
        }
        return root;
    }

    /// Right rotation around `node`; `node->left` must be non-null.
    static Node *rightRotate(Node *node) {
        Node *new_root = node->left;
        node->left = new_root->right;
        new_root->right = node;
        updateHeight(node);  // `node` is now the child, so fix it first.
        updateHeight(new_root);
        return new_root;
    }

    /// Left rotation around `node`; `node->right` must be non-null.
    static Node *leftRotate(Node *node) {
        Node *new_root = node->right;
        node->right = new_root->left;
        new_root->left = node;
        updateHeight(node);  // `node` is now the child, so fix it first.
        updateHeight(new_root);
        return new_root;
    }

    /// Refreshes the height of `node` and restores the AVL invariant with at
    /// most a double rotation. Returns the new root of the subtree.
    static Node *balance(Node *node) {
        updateHeight(node);
        const int factor = bfactor(node);
        if (factor == 2) {
            if (bfactor(node->right) < 0) {
                node->right = rightRotate(node->right);
            }
            return leftRotate(node);
        }
        if (factor == -2) {
            if (bfactor(node->left) > 0) {
                node->left = leftRotate(node->left);
            }
            return rightRotate(node);
        }
        return node;
    }

    /// In-order walk writing values into `array` starting at `*index`.
    static void traversal(int *array, int *index, const Node *node) {
        if (node == nullptr) {
            return;
        }
        traversal(array, index, node->left);
        array[(*index)++] = node->value;
        traversal(array, index, node->right);
    }

    /// Deletes every node of the subtree at `node`.
    static void clear(Node *node) {
        if (node == nullptr) {
            return;
        }
        clear(node->left);
        clear(node->right);
        delete node;
    }
};
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement