tsypko

splay tree

Nov 9th, 2017
82
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 3.09 KB | None | 0 0
  1. #include <bits/stdc++.h>
  2. using namespace std;
  3.  
  4. #define null NULL
  5.  
  6. struct splay_tree {
  7. private:
  8.     struct node {
  9.         int key;
  10.         node *l, *r, *p;
  11.         node() {
  12.             l = r = p = null;
  13.             key = 0;
  14.         }
  15.         node(int key) {
  16.             this->key = key;
  17.             l = r = p = null;
  18.         }
  19.     };
  20.     typedef node *tnode;
  21.     node *root;
  22.     void set_parent(tnode child, tnode parent) {
  23.         if (child) {
  24.             child->p = parent;
  25.         }
  26.     }
  27.     void keep_parent(tnode v) {
  28.         set_parent(v->l, v);
  29.         set_parent(v->r, v);
  30.     }
  31.     void rotate(tnode child, tnode parent) {
  32.         assert(child);
  33.         assert(parent);
  34.         assert(parent->l == child || parent->r == child);
  35.         assert(child->p == parent);
  36.         tnode gparent = parent->p;
  37.         if (gparent) {
  38.             (gparent->l == parent ? gparent->l : gparent->r) = child;
  39.         }
  40.         if (parent->l == child) {
  41.             parent->l = child->r;
  42.             child->r = parent;
  43.         } else {
  44.             parent->r = child->l;
  45.             child->l = parent;
  46.         }
  47.         keep_parent(child);
  48.         keep_parent(parent);
  49.         child->p = gparent;
  50.     }
  51.     tnode splay(tnode v) {
  52.         if (!v->p)
  53.             return v;
  54.         if (!v->p->p) {
  55.             rotate(v, v->p);
  56.             return v;
  57.         }
  58.         tnode parent = v->p;
  59.         tnode gparent = parent->p;
  60.         bool zigzig = (parent->l == v) == (gparent->l == parent);
  61.         if (zigzig) {
  62.             rotate(parent, gparent);
  63.             rotate(v, parent);
  64.         } else {
  65.             rotate(v, parent);
  66.             rotate(v, gparent);
  67.         }
  68.         return splay(v);
  69.     }
  70.     tnode find(tnode v, int key) {
  71.         if (!v)
  72.             return null;
  73.         if (key == v->key) {
  74.             return splay(v);
  75.         }
  76.         if (v->key > key && v->l) {
  77.             return find(v->l, key);
  78.         } else if (v->key < key && v->r) {
  79.             return find(v->r, key);
  80.         }
  81.         return splay(v);
  82.     }
  83.     void split_without_del(tnode t, int key, tnode &l, tnode &r) {
  84.         if (!t) {
  85.             l = r = null;
  86.             return;
  87.         }
  88.         t = find(t, key);
  89.         if (t->key < key) {
  90.             r = t->r;
  91.             set_parent(r, null);
  92.             l = t;
  93.             l->r = null;
  94.         } else {
  95.             l = t->l;
  96.             set_parent(l, null);
  97.             r = t;
  98.             r->l = null;
  99.         }
  100.     }
  101.     void split(tnode t, int key, tnode &l, tnode &r) {
  102.         if (!t) {
  103.             l = r = null;
  104.             return;
  105.         }
  106.         t = find(t, key);
  107.         if (t->key == key) {
  108.             l = t->l;
  109.             r = t->r;
  110.             set_parent(l, null);
  111.             set_parent(r, null);
  112.         } else if (t->key < key) {
  113.             r = t->r;
  114.             set_parent(r, null);
  115.             l = t;
  116.             l->r = null;
  117.         } else {
  118.             l = t->l;
  119.             set_parent(l, null);
  120.             r = t;
  121.             r->l = null;
  122.         }
  123.     }
  124.     tnode merge(tnode l, tnode r) {
  125.         if (!l)
  126.             return r;
  127.         if (!r)
  128.             return l;
  129.         r = find(r, l->key);
  130.         r->l = l;
  131.         l->p = r;
  132.         return r;
  133.     }
  134.     void remove(tnode &root, int key) {
  135.         tnode t = find(root, key);
  136.         tnode l = t->l;
  137.         tnode r = t->r;
  138.         set_parent(l, null);
  139.         set_parent(r, null);
  140.         root = merge(l, r);
  141.     }
  142.     void insert(tnode &root, int key) {
  143.         tnode l, r;
  144.         split(root, key, l, r);
  145.         tnode v = new node(key);
  146.         root = merge(merge(l, v), r);
  147.     }
  148. public:
  149.     splay_tree() {
  150.         root = null;
  151.     }
  152.     void insert(int val) {
  153.         insert(root, val);
  154.     }
  155.     int find(int val) {
  156.         tnode l, r;
  157.         split_without_del(root, val, l, r);
  158.         int ans = -1;
  159.         if (r) {
  160.             r = find(r, INT_MIN);
  161.             ans = r->key;
  162.         }
  163.         root = merge(l, r);
  164.         return ans;
  165.     }
  166. };
  167.  
  168. int main(int argc, char **argv) {
  169.  
  170. }
Add Comment
Please, Sign In to add comment