spider68

Range Sum Query - Mutable

Jun 18th, 2021
648
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  /// One node of a binary segment tree covering the closed index range
  /// [start, end].
  ///
  /// A freshly constructed node has sum == 0 and no children. The owning tree
  /// (NumArray in this file) fills `sum` and links `left`/`right`. The node
  /// itself never frees its children.
  struct SegmentTreeNode {
      int start;
      int end;
      int sum = 0;
      SegmentTreeNode* left = nullptr;
      SegmentTreeNode* right = nullptr;

      SegmentTreeNode(int lo, int hi) : start(lo), end(hi) {}
  };
  7. class NumArray {
  8.     SegmentTreeNode* root;
  9. public:
  10.     NumArray(vector<int> &nums) {
  11.         int n = nums.size();
  12.         root = buildTree(nums,0,n-1);
  13.     }
  14.    
  15.     void update(int i, int val) {
  16.         modifyTree(i,val,root);
  17.     }
  18.  
  19.     int sumRange(int i, int j) {
  20.         return queryTree(i, j, root);
  21.     }
  22.     SegmentTreeNode* buildTree(vector<int> &nums, int start, int end) {
  23.         if(start > end) return nullptr;
  24.         SegmentTreeNode* root = new SegmentTreeNode(start,end);
  25.         if(start == end) {
  26.             root->sum = nums[start];
  27.             return root;
  28.         }
  29.         int mid = start + (end - start) / 2;
  30.         root->left = buildTree(nums,start,mid);
  31.         root->right = buildTree(nums,mid+1,end);
  32.         root->sum = root->left->sum + root->right->sum;
  33.         return root;
  34.     }
  35.     int modifyTree(int i, int val, SegmentTreeNode* root) {
  36.         if(root == nullptr) return 0;
  37.         int diff;
  38.         if(root->start == i && root->end == i) {
  39.             diff = val - root->sum;
  40.             root->sum = val;
  41.             return diff;
  42.         }
  43.         int mid = (root->start + root->end) / 2;
  44.         if(i > mid) {
  45.             diff = modifyTree(i,val,root->right);
  46.         } else {
  47.             diff = modifyTree(i,val,root->left);
  48.         }
  49.         root->sum = root->sum + diff;
  50.         return diff;
  51.     }
  52.     int queryTree(int i, int j, SegmentTreeNode* root) {
  53.         if(root == nullptr) return 0;
  54.         if(root->start == i && root->end == j) return root->sum;
  55.         int mid = (root->start + root->end) / 2;
  56.         if(i > mid) return queryTree(i,j,root->right);
  57.         if(j <= mid) return queryTree(i,j,root->left);
  58.         return queryTree(i,mid,root->left) + queryTree(mid+1,j,root->right);
  59.     }
  60. };
RAW Paste Data