Advertisement
manish

lazy segment tree

Mar 22nd, 2021
699
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.63 KB | None | 0 0
  1. func_lazy = lambda a,b,c,d: (a*c, b + a*d)
  2. func_data = lambda a,b,val: a * val + b
  3.  
  4. class LazySegmentTree:
  5.     def __init__(self, data, padding=0, func=lambda a, b: a + b):
  6.         """initialize the lazy segment tree with data"""
  7.         self._func = func
  8.         self._len = len(data)
  9.         self._size = _size = 1 << (self._len - 1).bit_length()
  10.  
  11.         self.data = [padding] * (2 * _size)
  12.         self.data[_size:_size + self._len] = data
  13.         for i in reversed(range(1, _size)):
  14.             self.data[i] = self.data[2 * i] + self.data[2 * i + 1]
  15.             # TODO
  16.         self._lazy = [1, 0] * (2 * _size)
  17.  
  18.     def _push(self, idx):
  19.         """push query on idx to its children"""
  20.         # Let the children know of the queries
  21.         # TODO
  22.         idx *= 2
  23.         a = self._lazy[idx]
  24.         b = self._lazy[idx + 1]
  25.         if (a, b) != (1, 0):
  26.             b >>= 1
  27.             self._lazy[idx] = 1
  28.             self._lazy[idx + 1] = 0
  29.  
  30.             self.data[idx] = a * self.data[idx] + b
  31.             self.data[idx + 1] = a * self.data[idx + 1] + b
  32.  
  33.             idx *= 2
  34.             self._lazy[idx], self._lazy[idx + 1] = func_lazy(a, b, self._lazy[idx], self._lazy[idx + 1])
  35.             self._lazy[idx + 2], self._lazy[idx + 3] = func_lazy(a, b, self._lazy[idx + 2], self._lazy[idx + 3])
  36.  
  37.     def _build(self, idx):
  38.         """make the changes to idx be known to its ancestors"""
  39.         idx >>= 1
  40.         while idx:
  41.             # TODO
  42.             # self._push(idx)
  43.             a = self._lazy[2 * idx]
  44.             b = self._lazy[2 * idx + 1]
  45.             self.data[idx] = a * (self.data[2 * idx] + self.data[2 * idx + 1]) + b
  46.             idx >>= 1
  47.  
  48.     def _update(self, idx):
  49.         """updates the node idx to know of all queries applied to it via its ancestors"""
  50.         for i in reversed(range(1, idx.bit_length())):
  51.             self._push(idx >> i)
  52.  
  53.     def update(self, start, stop, x):
  54.         self.updateval(start, stop, 0, x)
  55.  
  56.     def add(self, start, stop, x):
  57.         self.updateval(start, stop, 1, x)
  58.  
  59.     def updateval(self, start, stop, a, b):
  60.         """lazily add value to [start, stop)"""
  61.         stop += 1
  62.         start = start_copy = start + self._size
  63.         stop = stop_copy = stop + self._size
  64.  
  65.         # Apply all the lazily stored queries
  66.         self._update(start)
  67.         self._update(stop - 1)
  68.  
  69.         while start < stop:
  70.             if start & 1:
  71.                 self.data[start] = a * self.data[start] + b
  72.                 self._lazy[2 * start], self._lazy[2 * start + 1] = func_lazy(a, b,
  73.                 self._lazy[2 * start], self._lazy[2 * start + 1])
  74.                 start += 1
  75.             if stop & 1:
  76.                 stop -= 1
  77.                 self.data[stop] = a * self.data[stop] + b
  78.                 self._lazy[stop * 2], self._lazy[2 * stop + 1] = func_lazy(a, b,
  79.                 self._lazy[2 * stop], self._lazy[2 * stop + 1])
  80.             start >>= 1
  81.             stop >>= 1
  82.             b <<= 1
  83.  
  84.         self._build(start_copy)
  85.         self._build(stop_copy - 1)
  86.  
  87.     def query(self, start, stop, res=0):
  88.         stop += 1
  89.         start += self._size
  90.         stop += self._size
  91.  
  92.         # Apply all the lazily stored queries
  93.         self._update(start)
  94.         self._update(stop - 1)
  95.         while start < stop:
  96.             if start & 1:
  97.                 res += self.data[start]
  98.                 start += 1
  99.             if stop & 1:
  100.                 stop -= 1
  101.                 res += self.data[stop]
  102.             start >>= 1
  103.             stop >>= 1
  104.         return res
  105.  
  106.     def node_size(self, idx):
  107.         return 1 << self._size.bit_length() - idx.bit_length()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement