Advertisement
welleyth

UJGOI 2022 Day 3C

Aug 26th, 2022
800
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 8.74 KB | Science | 0 0
  1. #include <bits/stdc++.h>
  2. using namespace std;
  3.  
  4. //#define int long long
  5. #define mp make_pair
  6. #define pb push_back
  7.  
  8. #pragma GCC optimize("Ofast")
  9. #pragma GCC optimize("unroll-loops")
  10. #pragma GCC target("avx2")
  11.  
  12. //#pragma GCC optimize("O3,unroll-loops")
  13. //#pragma GCC target("avx2,bmi,bmi2,lzcnt,popcnt")
  14.  
  15. constexpr int INF = (int)1e9;
  16. constexpr int N = (1 << 18);
  17.  
  18. vector<int> g[N];
  19.  
  20. struct HLD{
  21.     int t[4*N];
  22.     int n;
  23.     void init(int sz){
  24.         n = sz-1;
  25.         return;
  26.     }
  27.     struct SegTreeMn{
  28.         struct node{
  29.             int mn,pos;
  30.             int w;
  31.             node(){}
  32.             node(int x,int pos):pos(pos){
  33.                 mn = x;
  34.                 w = 0;
  35.             }
  36.             node(node l,node r){
  37.                 mn = min(l.mn,r.mn);
  38.                 if(mn == l.mn)
  39.                     pos = l.pos;
  40.                 else
  41.                     pos = r.pos;
  42.                 w = 0;
  43.             }
  44.         } t[4*N];
  45.         int n;
  46.         SegTreeMn(){}
  47.         void init(int sz){
  48.             n = sz-1;
  49.             return;
  50.         }
  51.         void push(int v){
  52.             t[v<<1].mn += t[v].w;
  53.             t[v<<1|1].mn += t[v].w;
  54.             t[v<<1].w += t[v].w;
  55.             t[v<<1|1].w += t[v].w;
  56.             t[v].w = 0;
  57.             return;
  58.         }
  59.         void build(int v,int l,int r,vector<int>& order,vector<int>& a){
  60.             if(l == r){
  61.                 t[v] = node(a[order[l]],l);
  62.                 return;
  63.             }
  64.             int m = (l+r)>>1;
  65.             build(v<<1,l,m,order,a);
  66.             build(v<<1|1,m+1,r,order,a);
  67.             t[v] = node(t[v<<1],t[v<<1|1]);
  68.             return;
  69.         }
  70.         void build(vector<int>& order,vector<int>& a){
  71.             n = order.size()-1;
  72.             build(1,0,n,order,a);
  73.             return;
  74.         }
  75.         void upd(int v,int l,int r,int pos,int nw){
  76.             if(l == r){
  77.                 t[v] = node(nw,pos);
  78.                 return;
  79.             }
  80.             push(v);
  81.             int m = (l+r)>>1;
  82.             if(pos <= m)
  83.                 upd(v<<1,l,m,pos,nw);
  84.             else
  85.                 upd(v<<1|1,m+1,r,pos,nw);
  86.             t[v] = node(t[v<<1],t[v<<1|1]);
  87.             return;
  88.         }
  89.         void updSegment(int v,int l,int r,int tl,int tr){
  90.             if(l > r || tl > tr)
  91.                 return;
  92.             if(l == tl && r == tr){
  93.                 t[v].w--;
  94.                 t[v].mn--;
  95.                 return;
  96.             }
  97.             push(v);
  98.             int m = (l+r)>>1;
  99.             updSegment(v<<1,l,m,tl,min(tr,m));
  100.             updSegment(v<<1|1,m+1,r,max(tl,m+1),tr);
  101.             t[v] = node(t[v<<1],t[v<<1|1]);
  102.             return;
  103.         }
  104.         void updSegment(int l,int r){
  105.             updSegment(1,0,n,l,r);
  106.             return;
  107.         }
  108.         void upd(int pos,int nw){
  109.             upd(1,0,n,pos,nw);
  110.             return;
  111.         }
  112.     }
  113.     tMin;
  114.  
  115.     struct SegTree{
  116.         int n;
  117.         int tree[2*N];
  118.         SegTree(){}
  119.         void init(int sz){
  120.             return;
  121.         }
  122.         void upd(int pos, int newval) { // arr[pos] := newval
  123.             pos += N;
  124.             tree[pos] = newval;
  125.             pos >>= 1;
  126.             while (pos > 0) {
  127.                 tree[pos] = tree[pos << 1] + tree[(pos << 1) | 1];
  128.                 pos >>= 1;
  129.             }
  130.         }
  131.         int get(int l, int r) { // [l, r)
  132.             r++;
  133.             l += N;
  134.             r += N;
  135.             int ans = 0;
  136.             while (l < r) {
  137.                 if (l & 1) {
  138.                     ans += tree[l++];
  139.                 }
  140.                 if (r & 1) {
  141.                     ans += tree[--r];
  142.                 }
  143.                 l >>= 1;
  144.                 r >>= 1;
  145.             }
  146.             return ans;
  147.         }
  148.     }
  149.     tSum;
  150.  
  151.     int d[N],sz[N];
  152.     int go[N];
  153.     int parent[N];
  154.     int tin[N],tout[N];
  155.     int head[N];
  156.     int p[N];
  157.  
  158.     void dfs1(int v,int pr = -1){ /// determine heavy paths
  159.         sz[v] = 1;
  160.         int mx = 0;
  161.         go[v] = -1;
  162.         for(auto& to : g[v]){
  163.             if(to == pr)
  164.                 continue;
  165.             dfs1(to,v);
  166.             if(sz[to] > mx){
  167.                 mx = sz[to];
  168.                 go[v] = to;
  169.             }
  170.             sz[v] += sz[to];
  171.         }
  172.         return;
  173.     }
  174.  
  175.     vector<int> order;
  176.  
  177.     void dfs2(int v,int h,int pr = 0){
  178.         p[v] = pr;
  179.         tin[v] = order.size();
  180.         head[v] = h;
  181.         order.pb(v);
  182.         if(go[v] == -1)
  183.             return;
  184.         dfs2(go[v],h,v);
  185.         for(auto& to : g[v]){
  186.             if(to == pr || to == go[v])
  187.                 continue;
  188.             dfs2(to,to,v);
  189.         }
  190.         tout[v] = order.size()-1;
  191.         return;
  192.     }
  193.  
  194.     int cnt[N];
  195.     int b[N],s[N];
  196.     vector<int> was[N];
  197.     vector<int> sons[N];
  198.     int pp[N];
  199.     bool deleted[N];
  200.  
  201.     int find_parent(int x){
  202.         if(deleted[pp[x]])
  203.             pp[x] = find_parent(pp[x]);
  204.         return pp[x];
  205.     }
  206.  
  207.     void dfs(int v,int pr = -1){
  208.         cnt[s[v]]++;
  209.         pp[v] = 0;
  210.         if(!was[s[v]].empty()){
  211.             pp[v] = was[s[v]].back();
  212.             sons[was[s[v]].back()].pb(v);
  213.         }
  214.         was[s[v]].pb(v);
  215.         if(cnt[s[v]] == 1){
  216.             tSum.upd(tin[v],1);
  217. //            cerr << "added in " << v << "\n";
  218.         }
  219.         tMin.upd(tin[v],b[v]);
  220.         for(auto& to : g[v]){
  221.             if(to == pr)
  222.                 continue;
  223.             dfs(to,v);
  224.         }
  225.         cnt[s[v]]--;
  226.         was[s[v]].pop_back();
  227.         return;
  228.     }
  229.  
  230.     int get(int v){
  231.         int ans = 0;
  232. //        cerr << "v = " << v << "\n";
  233. //        cerr << "h = " << head[v] << "\n";
  234.         for(int h = head[v];; h = head[v]){
  235.             ans += tSum.get(tin[h],tin[v]);
  236.             tMin.updSegment(tin[h],tin[v]);
  237.             if(h == 1)
  238.                 break;
  239.             v = p[h];
  240.         }
  241.         return ans;
  242.     }
  243.  
  244.     void Normalize(){
  245.         do{
  246.             auto p = tMin.t[1];
  247.             if(p.mn > 0)
  248.                 break;
  249.             int pos = p.pos;
  250.             int v = order[pos];
  251. //            cerr << "going to delete " << v << " " << p.mn << "\n";
  252.             deleted[v] = true;
  253.             tMin.upd(pos,INF);
  254.             if(tSum.get(tin[v],tin[v]) == 1){
  255.                 tSum.upd(tin[v],0);
  256.                 for(auto& x : sons[v]){
  257.                     if(!deleted[x]){
  258.                         tSum.upd(tin[x],1);
  259. //                        cerr << "added " << x << "\n";
  260.                     }
  261.                 }
  262.             } else {
  263.                 find_parent(v);
  264.                 int pr = pp[v];
  265.                 if(pr != 0){
  266.                     if(sons[v].size() > sons[pr].size()){
  267.                         for(auto& x : sons[pr])
  268.                             sons[v].pb(x);
  269.                         swap(sons[v],sons[pr]);
  270.                     } else {
  271.                         for(auto& x : sons[v])
  272.                             sons[pr].pb(x);
  273.                     }
  274.                 }
  275.             }
  276.             deleted[v] = true;
  277. //            cerr << "deleted " << v << "\n";
  278.         }while(true);
  279.     }
  280.  
  281.     void init(int sz,vector<int>& B,vector<int>& S){
  282.         n = sz-1;
  283.         tMin.init(sz);
  284.         tSum.init(sz);
  285.         dfs1(1);
  286.         dfs2(1,1);
  287.         for(int i = 0; i < B.size(); i++){
  288.             b[i] = B[i];
  289.             s[i] = S[i];
  290.         }
  291.         dfs(1);
  292.         Normalize();
  293.     }
  294.  
  295.     int answer(int v){
  296.         int ans = get(v);
  297.         Normalize();
  298.         return ans;
  299.     }
  300.  
  301. } Solver;
  302.  
  303. void solve(){
  304.     int n,d;
  305. //    cin >> n >> d;
  306.     scanf("%d %d",&n,&d);
  307.  
  308.     vector<int> b(n+1,0);
  309.     vector<int> s(n+1,0);
  310.  
  311.     for(int i = 1; i <= n; i++){
  312. //        cin >> b[i];
  313.         scanf("%d",&b[i]);
  314.     }
  315.     for(int i = 1; i <= n; i++){
  316. //        cin >> s[i];
  317.         scanf("%d",&s[i]);
  318.     }
  319.  
  320.     vector<int> queries;
  321.     for(int i = 0; i < d; i++){
  322.         int v;
  323. //        cin >> v;
  324.         scanf("%d",&v);
  325.         queries.pb(v);
  326.     }
  327.  
  328.     for(int i = 1; i < n; i++){
  329.         int a,b;
  330. //        cin >> a >> b;
  331.         scanf("%d %d",&a,&b);
  332.         g[a].pb(b);
  333.         g[b].pb(a);
  334.     }
  335.  
  336.     Solver.init(n,b,s);
  337.     for(auto& v : queries){
  338. //        cerr << "Next QUERY!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n";
  339.         int ans = Solver.answer(v);
  340. //        cerr << "ANSWER = " << ans << "\n\n";
  341.         printf("%d ",ans);
  342.     }
  343.  
  344.     return;
  345. }
  346.  
  347. signed main(){
  348.     ios::sync_with_stdio(false);cin.tie(nullptr);cout.tie(nullptr);
  349.     int tests = 1;
  350. //    cin >> tests;
  351.     for(int test = 1; test <= tests; test++){
  352. //        cerr << "test = " << test << "\n";
  353.         solve();
  354.     }
  355.     return 0;
  356. }
  357. /**
  358. O(Nlog^2)
  359.  
  360. 5 1
  361. 4 1 0 3 1
  362. 1 3 2 2 1
  363. 2
  364. 5 2
  365. 2 1
  366. 1 4
  367. 1 3
  368.  
  369. **/
  370.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement