Advertisement
welleyth

UJGOI 2022 D. Anton the guard

Aug 17th, 2023
923
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 2.80 KB | None | 0 0
  1. #include <bits/stdc++.h>
  2.  
  3. using namespace std;
  4.  
  5. #pragma GCC optimize("Ofast")
  6. #pragma GCC optimize("unroll-loops")
  7. #pragma GCC target("avx2")
  8.  
  // Capacity of every per-vertex array (vertices are 1-indexed, with slack).
  constexpr int N = 1'000'011;
  // "Infinite" cost sentinel; leaves ample headroom below LLONG_MAX for the
  // INF - distance arithmetic done in dfs_compress / dfs_recalc_dp.
  constexpr long long INF = 1'000'000'000'000'000'000LL;

  // Undirected input tree as (neighbour, edge weight) lists.
  // NOTE(review): weights are stored as int here although w[] is read as
  // long long — confirm the problem guarantees weights fit in int.
  vector<pair<int,int>> initialGraph[N];
  // First the input weight of the edge (i, p[i]); dfs_down_edges then
  // overwrites it with the weighted distance from the root (w[root] == 0).
  long long w[N];
  // g[v]: children of v; dfs_compress later rewrites it into the compressed
  // tree of the level being processed.
  vector<int> g[N];
  // Scratch list in which dfs_compress collects v's surviving children.
  vector<int> compressedGraph[N];
  // deep[v]: depth of v, with the root at depth 1.
  int deep[N];
  // mn[v][0] <= mn[v][1]: the two smallest (value, child) candidates that
  // dfs_compress found among v's compressed children.
  pair<long long,int> mn[N][2];
  // p[v]: parent of v.
  int p[N];
  // dp[v]: value computed by dfs_recalc_dp; the answer is the minimum over
  // the deepest level.
  long long dp[N];
  19.  
  20. void dfs_down_edges(int v,int pr = -1){
  21.     for(auto&[to,len] : initialGraph[v]){
  22.         if(to == pr) continue;
  23.         p[to] = v;
  24.         w[to] = w[v] + len;
  25.         g[v].push_back(to);
  26.         deep[to] = deep[v] + 1;
  27.         dfs_down_edges(to,v);
  28.     }
  29.     return;
  30. }
  31.  
  32. int dfs_compress(int v,int depth,long long& total_weight){
  33.     if(deep[v] == depth){
  34.         return v;
  35.     }
  36.     compressedGraph[v].clear();
  37.     mn[v][0] = mn[v][1] = make_pair(INF,-1);
  38.     for(auto& to : g[v]){
  39.         int u = dfs_compress(to,depth,total_weight);
  40.         if(u != -1){
  41.             compressedGraph[v].push_back(u);
  42.             if(deep[u] == depth){
  43.                 long long len = w[p[u]] - w[v];
  44.                 mn[v][1] = min(mn[v][1],make_pair(dp[p[u]]-len,u));
  45.             } else {
  46.                 long long len = w[u] - w[v];
  47.                 mn[v][1] = min(mn[v][1],make_pair(mn[u][0].first-len,u));
  48.             }
  49.             if(mn[v][0] > mn[v][1]) swap(mn[v][0],mn[v][1]);
  50.         }
  51.     }
  52.     swap(g[v],compressedGraph[v]);
  53.     if(g[v].size() > 1){
  54.         for(auto& to : g[v]){
  55.             total_weight += w[to] - w[v]; /// add the length of the edges
  56.         }
  57.         return v;
  58.     }
  59.     else
  60.         return (g[v].empty() ? -1 : g[v][0]);
  61. }
  62.  
  63. void dfs_recalc_dp(int v,int depth,long long min_len,const long long& total_weight){
  64.     if(deep[v] == depth){
  65.         int len = w[v] - w[p[v]];
  66.         min_len = min(min_len,dp[p[v]]-len);
  67.         dp[v] = 2 * total_weight + min_len;
  68.         return;
  69.     }
  70.     for(auto& to : g[v]){
  71.         long long len = w[to] - w[v];
  72.         long long next_min_len = min_len - len;
  73.         bool goMinVertex = (mn[v][0].second == to);
  74.         next_min_len = min(next_min_len,mn[v][goMinVertex].first-len);
  75.         dfs_recalc_dp(to,depth,next_min_len,total_weight);
  76.     }
  77.     return;
  78. }
  79.  
  80. signed main(){
  81.     ios::sync_with_stdio(false);cin.tie(nullptr);cout.tie(nullptr);
  82.    
  83.     int n;
  84.     cin >> n;
  85.  
  86.     for(int i = 2; i <= n; i++){
  87.         cin >> p[i];
  88.     }
  89.  
  90.     for(int i = 2; i <= n; i++){
  91.         cin >> w[i];
  92.     }
  93.  
  94.     for(int i = 2; i <= n; i++){
  95.         initialGraph[p[i]].push_back(make_pair(i,w[i]));
  96.         initialGraph[i].push_back(make_pair(p[i],w[i]));
  97.     }
  98.  
  99.     int root = 1;
  100.  
  101.     deep[root] = 1;
  102.     dfs_down_edges(root);
  103.     int max_depth = 1;
  104.  
  105.     for(int depth = 2; depth <= n; depth++){
  106.         long long total_weight = 0;
  107.         root = dfs_compress(root,depth,total_weight);
  108.         if(root == -1) break;
  109.         if(deep[root] == depth){
  110.             total_weight = w[root] - w[p[root]];
  111.         }
  112.         max_depth = depth;
  113.         dfs_recalc_dp(root,depth,INF,total_weight);
  114.     }
  115.  
  116.     long long ans = INF;
  117.     for(int i = 1; i <= n; i++){
  118.         if(deep[i] == max_depth)
  119.             ans = min(ans,dp[i]);
  120.     }
  121.  
  122.     cout << ans << "\n";
  123.  
  124.     return 0;
  125. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement