Advertisement
volochai

Heavy-Light Decomposition

Jul 11th, 2022
77
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 3.72 KB | None | 0 0
  1. struct tree {
  2.     int n;
  3.     vector < int > t;
  4.  
  5.     void build(vector<int> &x) {
  6.         n = x.size();
  7.  
  8.         int tree_size = 1;
  9.         while ((1ll << tree_size) < n)
  10.             tree_size++;
  11.  
  12.         t.resize(4 * (1ll << tree_size));
  13.  
  14.         build(x, 1, 0, n - 1);
  15.     }
  16.  
  17.     void build(vector <int> &a, int v, int tl, int tr) {
  18.         if (tl == tr) {
  19.             t[v] = a[tl];
  20.             return;
  21.         }
  22.         int tm = (tl + tr) >> 1;
  23.         build(a, v << 1, tl, tm);
  24.         build(a, v << 1 | 1, tm + 1, tr);
  25.         t[v] = t[v << 1] + t[v << 1 | 1];
  26.     }
  27.  
  28.     void update(int pos, int val, int v, int tl, int tr) {
  29.         if (tl > pos || tr < pos)
  30.             return;
  31.         if (tl == tr && pos == tl) {
  32.             t[v] = val;
  33.             return;
  34.         }
  35.         int tm = (tl + tr) >> 1;
  36.         update(pos, val, v << 1, tl, tm);
  37.         update(pos, val, v << 1 | 1, tm + 1, tr);
  38.         t[v] = t[v << 1] + t[v << 1 | 1];
  39.     }
  40.  
  41.     int get(int l, int r, int v, int tl, int tr) {
  42.         if (tl > r || tr < l)
  43.             return 0;
  44.         if (l <= tl && tr <= r)
  45.             return t[v];
  46.         int tm = (tl + tr) >> 1;
  47.         return get(l, r, v << 1, tl, tm) +
  48.                 get(l, r, v << 1 | 1, tm + 1, tr);
  49.     }
  50.  
  51.     void update(int pos, int val) {
  52.         update(pos, val, 1, 0, n - 1);
  53.     }
  54.  
  55.     int get(int l, int r) {
  56.         if (l > r) swap(l, r);
  57.         return get(l, r, 1, 0, n - 1);
  58.     }
  59. };
  60.  
  61. struct HLD {
  62.     int n, root;
  63.     vector < vector <int> > g;
  64.  
  65.     vector < int > sz;
  66.     vector < int > depth;
  67.     vector < int > parent;
  68.  
  69.     vector < int > decompose;
  70.     vector < int > pos;
  71.     tree t;
  72.  
  73.     vector < int > top;
  74.  
  75.     void build(vector < vector<int> > &adj, vector < int > &cost, int _root) {
  76.         n = adj.size() - 1;
  77.         root = _root;
  78.         g = adj;
  79.  
  80.         sz.resize(n + 1);
  81.         depth.resize(n + 1);
  82.         parent.resize(n + 1);
  83.  
  84.         init(root);
  85.  
  86.         top.resize(n + 1);
  87.  
  88.         dfs(root, root);
  89.  
  90.         pos.resize(n + 1);
  91.         for (int i = 0; i < n; i++)
  92.             pos[decompose[i]] = i;
  93.  
  94.         for (int i = 1; i <= n; i++)
  95.             decompose[i - 1] = cost[decompose[i - 1]];
  96.  
  97.         t.build(decompose);
  98.     }
  99.  
  100.     void init(int v, int d = 0, int p = 1) {
  101.         sz[v]++;
  102.         depth[v] = d;
  103.         parent[v] = p;
  104.  
  105.         for (auto to : g[v])
  106.             if (to != p)
  107.                 init(to, d + 1, v), sz[v] += sz[to];
  108.     }
  109.  
  110.     void dfs(int v, int cur_top) {
  111.         decompose.push_back(v);
  112.         top[v] = cur_top;
  113.  
  114.         if (v != root && g[v].size() == 1)
  115.             return;
  116.  
  117.         int next_heavy = 0;
  118.         for (auto to : g[v])
  119.             if (sz[to] > sz[next_heavy] && to != parent[v])
  120.                 next_heavy = to;
  121.  
  122.         dfs(next_heavy, cur_top);
  123.  
  124.         for (auto to : g[v])
  125.             if (to != next_heavy && to != parent[v])
  126.                 dfs(to, to);
  127.     }
  128.  
  129.     int lca(int u, int v) {
  130.         while (top[u] != top[v])
  131.             if (depth[top[u]] > depth[top[v]])
  132.                 u = parent[top[u]];
  133.             else
  134.                 v = parent[top[v]];
  135.         return (depth[u] < depth[v] ? u : v);
  136.     }
  137.  
  138.     int get(int u, int v) {
  139.         int l = lca(u, v);
  140.         int res = 0;
  141.         while (top[u] != top[l]) {
  142.             res += t.get(pos[u], pos[top[u]]);
  143.             u = parent[top[u]];
  144.         }
  145.         res += t.get(pos[u], pos[l]);
  146.         while (top[v] != top[l]) {
  147.             res += t.get(pos[v], pos[top[v]]);
  148.             v = parent[top[v]];
  149.         }
  150.         res += t.get(pos[v], pos[l]);
  151.         return res - t.get(pos[l], pos[l]);
  152.     }
  153.  
  154.     void update(int s, int val) {
  155.         t.update(pos[s], val);
  156.     }
  157. };
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement