Advertisement
_takumi

c6b

Dec 19th, 2022
904
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 5.23 KB | None | 0 0
  1. #include <iostream>
  2. #include <algorithm>
  3.  
  4. struct Node {
  5.     int value;
  6.     int height;
  7.     struct Node *left;
  8.     struct Node *right;
  9. };
  10.  
  11. class AVLTree {
  12. private:
  13.     Node *root_;
  14.     int size_;
  15.  
  16.     void clear(Node *n) {
  17.         if (n == nullptr) {
  18.             return;
  19.         }
  20.         clear(n->left);
  21.         clear(n->right);
  22.         delete n;
  23.     }
  24.  
  25.     int height(Node *n) {
  26.         if (n == nullptr) {
  27.             return 0;
  28.         }
  29.         return n->height;
  30.     }
  31.  
  32.     void traversal(Node *n, int *p, int *i) {
  33.         if (n == nullptr) {
  34.             return;
  35.         }
  36.         traversal(n->left, p, i);
  37.         *(p + *i) = n->value;
  38.         (*i)++;
  39.         traversal(n->right, p, i);
  40.     }
  41.  
  42.     int balance(Node *n) {
  43.         if (n == nullptr) {
  44.             return 0;
  45.         }
  46.         return height(n->left) - height(n->right);
  47.     }
  48.  
  49.     Node *insert(int val, Node *n) {
  50.         if (n == nullptr) {
  51.             n = new Node;
  52.             n->value = val;
  53.             n->left = nullptr;
  54.             n->right = nullptr;
  55.             n->height = 1;
  56.             size_++;
  57.             return n;
  58.         }
  59.         if (val < n->value) {
  60.             n->left = insert(val, n->left);
  61.         } else if (val > n->value) {
  62.             n->right = insert(val, n->right);
  63.         } else {
  64.             return n;
  65.         }
  66.         updateHeight(n);
  67.         return rebalance(n);
  68.     }
  69.  
  70.     Node *erase(Node *n, int val) {
  71.         if (n == nullptr) {
  72.             return nullptr;
  73.         }
  74.         if (n->value > val) {
  75.             n->left = erase(n->left, val);
  76.         } else if (n->value < val) {
  77.             n->right = erase(n->right, val);
  78.         } else {
  79.             Node *left = n->left;
  80.             Node *right = n->right;
  81.             delete n;
  82.             size_--;
  83.             if (left == nullptr) {
  84.                 return right;
  85.             } else if (right == nullptr) {
  86.                 return left;
  87.             }
  88.             Node *tmp = right;
  89.             while (tmp->left != nullptr) {
  90.                 tmp = tmp->left;
  91.             }
  92.             tmp->left = left;
  93.             updateHeight(tmp);
  94.             return rebalance(tmp);
  95.         }
  96.         updateHeight(n);
  97.         return rebalance(n);
  98.     }
  99.  
  100.     Node *rebalance(Node *n) {
  101.         int b = balance(n);
  102.         if (b > 1 && balance(n->left) < 0) {
  103.             n->left = leftRotate(n->left);
  104.             return rightRotate(n);
  105.         } else if (b < -1 && balance(n->right) > 0) {
  106.             n->right = rightRotate(n->right);
  107.             return leftRotate(n);
  108.         } else if (b > 1) {
  109.             return rightRotate(n);
  110.         } else if (b < -1) {
  111.             return leftRotate(n);
  112.         }
  113.         return n;
  114.     }
  115.  
  116.     Node *rightRotate(Node *n) {
  117.         Node *rotated = n->left;
  118.         Node *right = rotated->right;
  119.         rotated->right = n;
  120.         n->left = right;
  121.         updateHeight(n);
  122.         updateHeight(rotated);
  123.         return rotated;
  124.     }
  125.  
  126.     Node *leftRotate(Node *n) {
  127.         Node *rotated = n->right;
  128.         Node *left = rotated->left;
  129.         rotated->left = n;
  130.         n->right = left;
  131.         updateHeight(n);
  132.         updateHeight(rotated);
  133.         return rotated;
  134.     }
  135.  
  136.     void updateHeight(Node *n) {
  137.         int h = std::max(height(n->left), height(n->right));
  138.         n->height = h + 1;
  139.     }
  140.  
  141.     Node *find(Node *n, int val) {
  142.         if (n == nullptr) {
  143.             return nullptr;
  144.         }
  145.         if (n->value > val) {
  146.             return find(n->left, val);
  147.         } else if (n->value < val) {
  148.             return find(n->right, val);
  149.         } else {
  150.             return n;
  151.         }
  152.     }
  153.  
  154. public:
  155.     AVLTree() {
  156.         root_ = nullptr;
  157.         size_ = 0;
  158.     }
  159.  
  160.     int getHeight() {
  161.         return height(root_);
  162.     }
  163.  
  164.     Node *getRoot() {
  165.         return root_;
  166.     }
  167.  
  168.     bool empty() {
  169.         return (root_ == nullptr);
  170.     }
  171.  
  172.     int getSize() {
  173.         return size_;
  174.     }
  175.  
  176.     int *find(int value) {
  177.         Node *res = find(root_, value);
  178.         if (res != nullptr) {
  179.             return &(res->value);
  180.         }
  181.         return nullptr;
  182.     }
  183.  
  184.     void insert(int val) {
  185.         root_ = insert(val, root_);
  186.     }
  187.  
  188.     void erase(int value) {
  189.         root_ = erase(root_, value);
  190.     }
  191.  
  192.     int *traversal() {
  193.         if (root_ == nullptr) {
  194.             return nullptr;
  195.         }
  196.         int *arr = new int[getSize()];
  197.         int i = 0;
  198.         traversal(root_, arr, &i);
  199.         return arr;
  200.     }
  201.  
  202.     int *lowerBound(int val) {
  203.         int *arr = traversal();
  204.         int *ans = nullptr;
  205.         int size = getSize();
  206.         for (int i = 0; i < size; ++i) {
  207.             if (arr[i] >= val) {
  208.                 ans = find(arr[i]);
  209.                 break;
  210.             }
  211.         }
  212.         delete arr;
  213.         return ans;
  214.     }
  215.  
  216.     ~AVLTree() {
  217.         clear(root_);
  218.     }
  219. };
  220.  
  221. int main() {
  222.     AVLTree tree;
  223.     std::cout << tree.lowerBound(9) << std::endl;
  224.     tree.erase(1000);
  225.     tree.insert(0);
  226.     tree.insert(1);
  227.     tree.insert(2);
  228.     std::cout << tree.lowerBound(0) << std::endl;
  229.     tree.erase(1);
  230.     tree.erase(2);
  231.     tree.erase(0);
  232.     std::cout << tree.lowerBound(0) << std::endl;
  233.     return 0;
  234. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement