Advertisement
Guest User

Untitled

a guest
Oct 23rd, 2016
84
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 15.16 KB | None | 0 0
  1. #include <iostream>
  2. #include <chrono>
  3. #include <random>
  4. #include <cassert>
  5.  
  6. using namespace std;
  7.  
  8. class Sequence
  9. {
  10. public:
  11.   Sequence(int n) : m_rbTree(n)
  12.   {
  13.     for (int i = 0; i < n; ++i)
  14.     {
  15.       m_rbTree.insert(i);
  16.     }
  17.   }
  18.  
  19.   int get(int i) const
  20.   {
  21.     assert(i >= 0 && i < size() && "Index out of range.");
  22.     return m_rbTree.get(i);
  23.   }
  24.  
  25.   void erase(int i)
  26.   {
  27.     assert(i >= 0 && i < size() && "Index out of range.");
  28.     m_rbTree.erase(i);
  29.   }
  30.  
  31.   int size() const
  32.   {
  33.     return m_rbTree.size();
  34.   }
  35.  
  36. private:
  37.  
  38.   class RedBlackTree
  39.   {
  40.   public:
  41.     RedBlackTree(size_t nodesCount)
  42.       : m_pool(nodesCount + 1)
  43.     {
  44.       void* ptr = m_pool.allocate();
  45.       m_nil = new(ptr)Node();
  46.       m_root = m_nil;
  47.  
  48.       m_nil->parent     = nullptr;
  49.       m_nil->leftChild  = nullptr;
  50.       m_nil->rightChild = nullptr;
  51.       m_nil->key        = -1;
  52.       m_nil->size       =  0;
  53.       m_nil->color      = Color::BLACK;
  54.     }
  55.  
  56.     void insert(int key)
  57.     {
  58.       void* ptr = m_pool.allocate();
  59.       Node* node = new(ptr)Node(key);
  60.       insert(node);
  61.     }
  62.  
  63.     int get(int i) const
  64.     {
  65.       Node* n = select(m_root, i + 1);
  66.       return n->key;
  67.     }
  68.  
  69.     void erase(int i)
  70.     {
  71.       Node* n = select(m_root, i + 1);
  72.       remove(n);
  73.     }
  74.  
  75.     int size() const
  76.     {
  77.       return m_root->size;
  78.     }
  79.  
  80.   private:
  81.     enum class Color : uint8_t { RED, BLACK };
  82.  
  83.     struct Node
  84.     {
  85.       Node() = default;
  86.       Node(int key_) : key(key_) {}
  87.  
  88.       Node* parent;
  89.       Node* leftChild;
  90.       Node* rightChild;
  91.       int   key;
  92.       int   size;
  93.       Color color;
  94.     };
  95.  
  96.     class MemoryPool
  97.     {
  98.     public:
  99.       MemoryPool(size_t nodesCount)
  100.         : m_memory(new Node[nodesCount])
  101.         , m_currentFreeNode(m_memory)
  102.       {}
  103.  
  104.       ~MemoryPool()
  105.       {
  106.         delete[] m_memory;
  107.       }
  108.  
  109.       Node* allocate()
  110.       {
  111.         auto result = m_currentFreeNode;
  112.         ++m_currentFreeNode;
  113.         return result;
  114.       }
  115.  
  116.     private:
  117.       Node* m_memory;
  118.       Node* m_currentFreeNode;
  119.     };
  120.  
  121.   private:
  122.     Node* select(Node* x, int i) const
  123.     {
  124.       int r = x->leftChild->size + 1;
  125.  
  126.       if (i == r)
  127.         return x;
  128.       else if (i < r)
  129.         return select(x->leftChild , i);
  130.       else
  131.         return select(x->rightChild, i - r);
  132.     }
  133.  
  134.     void insert(Node* z)
  135.     {
  136.       Node* y = m_nil;
  137.       Node* x = m_root;
  138.  
  139.       while (x != m_nil)
  140.       {
  141.         ++x->size;
  142.         y = x;
  143.         if (z->key < x->key)
  144.           x = x->leftChild;
  145.         else
  146.           x = x->rightChild;
  147.       }
  148.  
  149.       z->parent = y;
  150.  
  151.       if (y == m_nil)
  152.         m_root = z;
  153.       else if (z->key < y->key)
  154.         y->leftChild  = z;
  155.       else
  156.         y->rightChild = z;
  157.  
  158.       z->leftChild  = m_nil;
  159.       z->rightChild = m_nil;
  160.       z->size       = 1;
  161.       z->color      = Color::RED;
  162.  
  163.       insertFixup(z);
  164.     }
  165.  
  166.     void insertFixup(Node* z)
  167.     {
  168.       assert(z);
  169.       assert(z->parent);
  170.  
  171.       while (z->parent->color == Color::RED)
  172.       {
  173.         assert(z->parent->parent);
  174.  
  175.         if (z->parent == z->parent->parent->leftChild)
  176.         {
  177.           Node* y = z->parent->parent->rightChild;
  178.           assert(y);
  179.  
  180.           if (y->color == Color::RED)
  181.           {
  182.             z->parent->color = Color::BLACK;
  183.             y->color = Color::BLACK;
  184.             z->parent->parent->color = Color::RED;
  185.             z = z->parent->parent;
  186.           }
  187.           else
  188.           {
  189.             if (z == z->parent->rightChild)
  190.             {
  191.               z = z->parent;
  192.               leftRotate(z);
  193.             }
  194.             z->parent->color = Color::BLACK;
  195.             z->parent->parent->color = Color::RED;
  196.             rightRotate(z->parent->parent);
  197.           }
  198.         }
  199.         else
  200.         {
  201.           Node* y = z->parent->parent->leftChild;
  202.           assert(y);
  203.  
  204.           if (y->color == Color::RED)
  205.           {
  206.             z->parent->color = Color::BLACK;
  207.             y->color = Color::BLACK;
  208.             z->parent->parent->color = Color::RED;
  209.             z = z->parent->parent;
  210.           }
  211.           else
  212.           {
  213.             if (z == z->parent->leftChild)
  214.             {
  215.               z = z->parent;
  216.               rightRotate(z);
  217.             }
  218.             z->parent->color = Color::BLACK;
  219.             z->parent->parent->color = Color::RED;
  220.             leftRotate(z->parent->parent);
  221.           }
  222.         }
  223.       }
  224.  
  225.       m_root->color = Color::BLACK;
  226.     }
  227.  
  228.     void remove(Node* z)
  229.     {
  230.       assert(z);
  231.       assert(z != m_nil);
  232.  
  233.       Node* x;
  234.       Node* y = z;
  235.       Color yOriginalColor = y->color;
  236.  
  237.       // decrement sizes on path to the root
  238.       while (y != m_nil)
  239.       {
  240.         --y->size;
  241.         y = y->parent;
  242.       }
  243.  
  244.       if (z->leftChild == m_nil)
  245.       {
  246.         x = z->rightChild;
  247.         transplant(z, z->rightChild);
  248.       }
  249.       else if (z->rightChild == m_nil)
  250.       {
  251.         x = z->leftChild;
  252.         transplant(z, z->leftChild);
  253.       }
  254.       else
  255.       {
  256.         y = findMinimum(z->rightChild);
  257.         yOriginalColor = y->color;
  258.         x = y->rightChild;
  259.  
  260.         if (y->parent == z)
  261.         {
  262.           x->parent = y;
  263.         }
  264.         else
  265.         {
  266.           transplant(y, y->rightChild);
  267.           y->rightChild = z->rightChild;
  268.           y->rightChild->parent = y;
  269.  
  270.           Node* temp = x->parent;
  271.           while (temp != y)
  272.           {
  273.             --temp->size;
  274.             temp = temp->parent;
  275.           }
  276.  
  277.           y->size = y->leftChild->size + 1;
  278.         }
  279.  
  280.         transplant(z, y);
  281.         y->leftChild = z->leftChild;
  282.         y->leftChild->parent = y;
  283.         y->color = z->color;
  284.         y->size = y->leftChild->size + y->rightChild->size + 1;
  285.       }
  286.  
  287.       if (yOriginalColor == Color::BLACK)
  288.       {
  289.         removeFixup(x);
  290.       }
  291.     }
  292.  
  293.     void removeFixup(Node* x)
  294.     {
  295.       assert(x);
  296.  
  297.       while (x != m_root && x->color == Color::BLACK)
  298.       {
  299.         if (x == x->parent->leftChild)
  300.         {
  301.           Node* w = x->parent->rightChild;
  302.  
  303.           assert(w);
  304.           assert(w != m_nil);
  305.  
  306.           if (w->color == Color::RED)
  307.           {
  308.             w->color = Color::BLACK;
  309.             x->parent->color = Color::RED;
  310.             leftRotate(x->parent);
  311.             w = x->parent->rightChild;
  312.           }
  313.  
  314.           if (w->leftChild->color  == Color::BLACK &&
  315.               w->rightChild->color == Color::BLACK)
  316.           {
  317.             w->color = Color::RED;
  318.             x = x->parent;
  319.           }
  320.           else
  321.           {
  322.             if (w->rightChild->color == Color::BLACK)
  323.             {
  324.               w->leftChild->color = Color::BLACK;
  325.               w->color = Color::RED;
  326.               rightRotate(w);
  327.               w = x->parent->rightChild;
  328.             }
  329.  
  330.             w->color = x->parent->color;
  331.             x->parent->color = Color::BLACK;
  332.             w->rightChild->color = Color::BLACK;
  333.             leftRotate(x->parent);
  334.             x = m_root;
  335.           }
  336.         }
  337.         else
  338.         {
  339.           Node* w = x->parent->leftChild;
  340.  
  341.           assert(w);
  342.           assert(w != m_nil);
  343.  
  344.           if (w->color == Color::RED)
  345.           {
  346.             w->color = Color::BLACK;
  347.             x->parent->color = Color::RED;
  348.             rightRotate(x->parent);
  349.             w = x->parent->leftChild;
  350.           }
  351.  
  352.           if (w->leftChild->color  == Color::BLACK &&
  353.               w->rightChild->color == Color::BLACK)
  354.           {
  355.             w->color = Color::RED;
  356.             x = x->parent;
  357.           }
  358.           else
  359.           {
  360.             if (w->leftChild->color == Color::BLACK)
  361.             {
  362.               w->rightChild->color = Color::BLACK;
  363.               w->color = Color::RED;
  364.               leftRotate(w);
  365.               w = x->parent->leftChild;
  366.             }
  367.  
  368.             w->color = x->parent->color;
  369.             x->parent->color = Color::BLACK;
  370.             w->leftChild->color = Color::BLACK;
  371.             rightRotate(x->parent);
  372.             x = m_root;
  373.           }
  374.         }
  375.       }
  376.  
  377.       x->color = Color::BLACK;
  378.     }
  379.  
  380.     void leftRotate(Node* x)
  381.     {
  382.       assert(x);
  383.  
  384.       Node* y = x->rightChild;
  385.       x->rightChild = y->leftChild;
  386.  
  387.       if (y->leftChild != m_nil)
  388.         y->leftChild->parent = x;
  389.  
  390.       y->parent = x->parent;
  391.  
  392.       if (x->parent == m_nil)
  393.         m_root = y;
  394.       else if (x == x->parent->leftChild)
  395.         x->parent->leftChild  = y;
  396.       else
  397.         x->parent->rightChild = y;
  398.  
  399.       y->leftChild = x;
  400.       x->parent = y;
  401.       y->size = x->size;
  402.       x->size = x->leftChild->size + x->rightChild->size + 1;
  403.     }
  404.  
  405.     void rightRotate(Node* x)
  406.     {
  407.       assert(x);
  408.  
  409.       Node* y = x->leftChild;
  410.       x->leftChild = y->rightChild;
  411.  
  412.       if (y->rightChild != m_nil)
  413.         y->rightChild->parent = x;
  414.  
  415.       y->parent = x->parent;
  416.  
  417.       if (x->parent == m_nil)
  418.         m_root = y;
  419.       else if (x == x->parent->rightChild)
  420.         x->parent->rightChild = y;
  421.       else
  422.         x->parent->leftChild  = y;
  423.  
  424.       y->rightChild = x;
  425.       x->parent = y;
  426.       y->size = x->size;
  427.       x->size = x->leftChild->size +x->rightChild->size + 1;
  428.     }
  429.  
  430.     void transplant(Node* u, Node* v)
  431.     {
  432.       if (u->parent == m_nil)
  433.         m_root = v;
  434.       else if (u == u->parent->leftChild)
  435.         u->parent->leftChild  = v;
  436.       else
  437.         u->parent->rightChild = v;
  438.  
  439.       v->parent = u->parent;
  440.     }
  441.  
  442.     Node* findMinimum(Node* x) const
  443.     {
  444.       assert(x);
  445.       assert(x != m_nil);
  446.  
  447.       while (x->leftChild != m_nil)
  448.         x = x->leftChild;
  449.  
  450.       return x;
  451.     }
  452.  
  453.   private:
  454.     MemoryPool m_pool;
  455.     Node* m_nil;
  456.     Node* m_root;
  457.  
  458.   }; // RedBlackTree class
  459.  
  460. private:
  461.   RedBlackTree m_rbTree;
  462.  
  463. }; // Sequence class
  464.  
  465.  
  466. #ifdef __PRETTY_FUNCTION__
  467.   #define TEST_ASSERT(condition)                                           \
  468.     if (!(condition)) {                                                    \
  469.       std::cerr << "Assertion failed at " << __FILE__ << ":" << __LINE__;  \
  470.       std::cerr << " inside '" << __PRETTY_FUNCTION__ << "'" <<std::endl;  \
  471.       std::cerr << "Condition: " << #condition << std::endl;               \
  472.       abort();                                                             \
  473.     }
  474. #else
  475.   #define TEST_ASSERT(condition)                                           \
  476.     if (!(condition)) {                                                    \
  477.       std::cerr << "Assertion failed at " << __FILE__ << ":" << __LINE__;  \
  478.       std::cerr << "Condition: " << #condition << std::endl;               \
  479.       abort();                                                             \
  480.     }
  481. #endif
  482.  
  483. void test1()
  484. {
  485.   Sequence s(10);
  486.   for (int i = 0; i < 10; ++i)
  487.   {
  488.     TEST_ASSERT(s.get(i) == i);
  489.   }
  490. }
  491.  
  492. void test2()
  493. {
  494.   Sequence s(10);
  495.   s.erase(0);
  496.   s.erase(0);
  497.   s.erase(0);
  498.   for (int i = 0; i < 7; ++i)
  499.   {
  500.     TEST_ASSERT(s.get(i) == i + 3);
  501.   }
  502. }
  503.  
  504. void test3()
  505. {
  506.   Sequence s(10);
  507.   s.erase(9);
  508.   s.erase(0);
  509.   s.erase(7);
  510.   s.erase(0);
  511.   for (int i = 0; i < 6; ++i)
  512.   {
  513.     TEST_ASSERT(s.get(i) == i + 2);
  514.   }
  515. }
  516.  
  517. void test4()
  518. {
  519.   Sequence s(10);
  520.   s.erase(9);
  521.   s.erase(5);
  522.   s.erase(7);
  523.   s.erase(2);
  524.   s.erase(0);
  525.   s.erase(0);
  526.   s.erase(0);
  527.   s.erase(0);
  528.   s.erase(1);
  529.   TEST_ASSERT(s.get(0) == 6);
  530. }
  531.  
  532. void test5()
  533. {
  534.   Sequence s(10);
  535.   s.erase(5);
  536.   TEST_ASSERT(s.get(5) == 6);
  537.   s.erase(8);
  538.   s.erase(0);
  539.   TEST_ASSERT(s.get(5) == 7);
  540.   s.erase(3);
  541.   TEST_ASSERT(s.get(3) == 6);
  542.   s.erase(4);
  543.   TEST_ASSERT(s.get(0) == 1);
  544.   TEST_ASSERT(s.get(1) == 2);
  545.   TEST_ASSERT(s.get(2) == 3);
  546.   TEST_ASSERT(s.get(3) == 6);
  547.   TEST_ASSERT(s.get(4) == 8);
  548. }
  549.  
  550. void test6()
  551. {
  552.   Sequence s(3);
  553.   s.erase(1);
  554.   TEST_ASSERT(s.get(0) == 0);
  555.   TEST_ASSERT(s.get(1) == 2);
  556.   s.erase(1);
  557.   TEST_ASSERT(s.get(0) == 0);
  558.   s.erase(0);
  559. }
  560.  
  561. void eraseAlwaysZeroIndexTest()
  562. {
  563.   static constexpr int SIZE = 10'000'000;
  564.  
  565.   cout << "Starting eraseAlwaysZeroIndexTest with " << SIZE << " elements ..." << endl;
  566.  
  567.   auto timeBegin = chrono::high_resolution_clock::now();
  568.  
  569.   Sequence s(SIZE);
  570.   for (int i = 0; i < SIZE; ++i)
  571.   {
  572.     s.erase(0);
  573.   }
  574.  
  575.   auto timeEnd = chrono::high_resolution_clock::now();
  576.  
  577.   cout << "eraseAlwaysZeroIndexTest(): "
  578.        << chrono::duration_cast<chrono::milliseconds>(
  579.             timeEnd - timeBegin).count() << " ms." << endl;
  580. }
  581.  
  582. void eraseAlwaysLastIndexTest()
  583. {
  584.   static constexpr int SIZE = 10'000'000;
  585.  
  586.   cout << "Starting eraseAlwaysLastIndexTest with " << SIZE << " elements ..." << endl;
  587.  
  588.   auto timeBegin = chrono::high_resolution_clock::now();
  589.  
  590.   Sequence s(SIZE);
  591.   for (int i = 0; i < SIZE; ++i)
  592.   {
  593.     s.erase(SIZE - i - 1);
  594.   }
  595.  
  596.   auto timeEnd = chrono::high_resolution_clock::now();
  597.  
  598.   cout << "eraseAlwaysLastIndexTest(): "
  599.        << chrono::duration_cast<chrono::milliseconds>(
  600.             timeEnd - timeBegin).count() << " ms." << endl;
  601. }
  602.  
  603. void eraseRandomIndexTest()
  604. {
  605.   static constexpr int SIZE = 10'000'000;
  606.  
  607.   cout << "Starting eraseRandomIndexTest with " << SIZE << " elements ..." << endl;
  608.  
  609.   auto timeBegin = chrono::high_resolution_clock::now();
  610.  
  611.   random_device rd;
  612.   mt19937 gen(rd());
  613.  
  614.   Sequence s(SIZE);
  615.   for (int i = 0; i < SIZE; ++i)
  616.   {
  617.     uniform_int_distribution<> distribution(0, SIZE - i - 1);
  618.     s.erase(distribution(gen));
  619.   }
  620.  
  621.   auto timeEnd = chrono::high_resolution_clock::now();
  622.  
  623.   cout << "eraseRandomIndexTest(): "
  624.        << chrono::duration_cast<chrono::milliseconds>(
  625.             timeEnd - timeBegin).count() << " ms." << endl;
  626. }
  627.  
  628. class ArraySequence
  629. {
  630. public:
  631.   ArraySequence(int n)
  632.     : m_size(n)
  633.     , m_data(new int[m_size])
  634.   {
  635.     for (int i = 0; i < m_size; ++i)
  636.     {
  637.       m_data[i] = i;
  638.     }
  639.   }
  640.  
  641.   ~ArraySequence()
  642.   {
  643.     delete[] m_data;
  644.   }
  645.  
  646.   int get(int i) const
  647.   {
  648.     assert(i >= 0 && i < m_size && "Index out of range.");
  649.     return m_data[i];
  650.   }
  651.  
  652.   void erase(int i)
  653.   {
  654.     assert(i >= 0 && i < m_size && "Index out of range.");
  655.     --m_size;
  656.     for (int j = i; j < m_size; ++j)
  657.     {
  658.       m_data[j] = m_data[j + 1];
  659.     }
  660.   }
  661.  
  662.   int size() const
  663.   {
  664.     return m_size;
  665.   }
  666.  
  667. private:
  668.   int  m_size;
  669.   int* m_data;
  670. };
  671.  
  672.  
  673. void correctnessTest()
  674. {
  675.   static constexpr int SIZE = 10'000;
  676.  
  677.  cout << "Starting correctness test with " << SIZE << " elements ..." << endl;
  678.  
  679.  ArraySequence arraySequence(SIZE);
  680.  Sequence rbtSequence(SIZE);
  681.  
  682.  random_device rd;
  683.  mt19937 gen(rd());
  684.  
  685.  for (int i = 0; i < SIZE; ++i)
  686.  {
  687.    uniform_int_distribution<> distribution(0, SIZE - i - 1);
  688.  
  689.    int index = distribution(gen);
  690.    arraySequence.erase(index);
  691.    rbtSequence.erase(index);
  692.  
  693.    TEST_ASSERT(arraySequence.size() == rbtSequence.size());
  694.    for (int j = 0; j < arraySequence.size(); ++j)
  695.    {
  696.      TEST_ASSERT(arraySequence.get(j) == rbtSequence.get(j));
  697.    }
  698.  }
  699.  
  700.  cout << "Correctness test passed." << endl;
  701. }
  702.  
  703. int main()
  704. {
  705.  test1();
  706.  test2();
  707.  test3();
  708.  test4();
  709.  test5();
  710.  test6();
  711.  
  712.  correctnessTest();
  713.  
  714.  eraseAlwaysZeroIndexTest();
  715.  eraseAlwaysLastIndexTest();
  716.  eraseRandomIndexTest();
  717.  
  718.  return 0;
  719. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement