Advertisement
Guest User

Untitled

a guest
Jun 27th, 2017
49
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 7.47 KB | None | 0 0
  1. #include <bits/stdc++.h>
  2.        
  3. #define pub push_back
  4. #define ll long long
  5. #define mp make_pair
  6. #define x first
  7. #define y second
  8. #define all(a) a.begin(), a.end()
  9.        
  10. using namespace std;
  11.  
  12. const int mod = (int)1e9 + 7;
  13.  
  14. void add(int &a, int b){
  15.     a += b;
  16.     if (a >= mod) a -= mod;
  17. }
  18.  
  19. void mult(int &a, int b){
  20.     a = (ll)a * b % mod;
  21. }
  22.  
  23. struct vert{
  24.     int y, l, r, sz, sum, p, func, mi, ma;
  25.     pair<int, int> val;
  26.     vert() { l = -1, r = -1, y = rand() * (ll)rand() % mod, sz = 1, sum = 0, p = 1, func = 0; }
  27. };
  28.  
  29. vert t[30 * 200007];
  30. pair<int, int> tmp[200007];
  31. int kk = 0;
  32. int sz = 0;
  33.  
  34. struct dd{
  35.     int root;
  36.     dd() { root = -1; }
  37.  
  38.     void push(int v){
  39.         if (v == -1 || t[v].p == 1) return;
  40.         mult(t[v].sum, t[v].p);
  41.         mult(t[v].val.y, t[v].p);
  42.         mult(t[v].func, t[v].p);
  43.         if (t[v].l != -1) mult(t[t[v].l].p, t[v].p);
  44.         if (t[v].r != -1) mult(t[t[v].r].p, t[v].p);
  45.         t[v].p = 1;
  46.     }
  47.  
  48.     int getSz(int v){
  49.         if (v == -1) return 0;
  50.         return t[v].sz;
  51.     }
  52.  
  53.     int getSum(int v){
  54.         if (v == -1) return 0;
  55.         return t[v].sum;
  56.     }
  57.  
  58.     int getFunc(int v){
  59.         if (v == -1) return 0;
  60.         return t[v].func;
  61.     }
  62.  
  63.     int getMi(int v){
  64.         if (v == -1) return 1e9 + 7;
  65.         return t[v].mi;
  66.     }
  67.  
  68.     int getMa(int v){
  69.         if (v == -1) return -1;
  70.         return t[v].ma;
  71.     }
  72.  
  73.     void recalc(int v){
  74.         if (v == -1) return;
  75.         push(t[v].l);
  76.         push(t[v].r);
  77.         t[v].sz = getSz(t[v].l) + getSz(t[v].r) + 1;
  78.         t[v].mi = min(t[v].val.x, min(getMi(t[v].l), getMi(t[v].r)));
  79.         t[v].ma = max(t[v].val.x, max(getMa(t[v].l), getMa(t[v].r)));
  80.         t[v].func = getFunc(t[v].l) + getFunc(t[v].r);
  81.         if (t[v].func >= mod) t[v].func -= mod;
  82.         t[v].func += t[v].val.x * (ll)t[v].val.y % mod;
  83.         if (t[v].func >= mod) t[v].func -= mod;
  84.         t[v].sum = getSum(t[v].l) + getSum(t[v].r);
  85.         if (t[v].sum >= mod) t[v].sum -= mod;
  86.         t[v].sum += t[v].val.y;
  87.         if (t[v].sum >= mod) t[v].sum -= mod;
  88.     }
  89.  
  90.     int merge(int a, int b){
  91.         push(a), push(b);
  92.         if (a == -1) return b;
  93.         if (b == -1) return a;
  94.         if (t[a].y < t[b].y){
  95.             t[a].r = merge(t[a].r, b);
  96.             recalc(a);
  97.             return a;
  98.         } else {
  99.             t[b].l = merge(a, t[b].l);
  100.             recalc(b);
  101.             return b;
  102.         }
  103.     }
  104.  
  105.     pair<int, int> splitVal(int v, int val){
  106.         if (v == -1) return mp(-1, -1);
  107.         push(v);
  108.         if (t[v].val.x >= val){
  109.             auto now = splitVal(t[v].l, val);
  110.             t[v].l = now.y;
  111.             recalc(v);
  112.             return mp(now.x, v);
  113.         } else {
  114.             auto now = splitVal(t[v].r, val);
  115.             t[v].r = now.x;
  116.             recalc(v);
  117.             return mp(v, now.y);
  118.         }
  119.     }
  120.  
  121.     int qGetNumVal(int v, int val){
  122.         if (v == -1) return -1;
  123.         push(v);
  124.         if (t[v].val.x == val) return v;
  125.         if (t[v].val.x > val) return qGetNumVal(t[v].l, val);
  126.         return qGetNumVal(t[v].r, val);
  127.     }
  128.  
  129.     int qGetSumVal(int v, int x){
  130.         if (v == -1) return 0;
  131.         push(v);
  132.         if (t[v].val.x < x){
  133.             int ans = t[v].val.y;
  134.             push(t[v].l);
  135.             add(ans, getSum(t[v].l));
  136.             add(ans, qGetSumVal(t[v].r, x));
  137.             return ans;
  138.         } else {
  139.             return qGetSumVal(t[v].l, x);
  140.         }
  141.     }
  142.    
  143.     void qMultGo(int v, int l, int r, int val){
  144.         if (v == -1) return;
  145.         push(v);
  146.         int vl = t[v].mi, vr = t[v].ma;
  147.         if (r < vl || l > vr) return;
  148.         if (vl >= l && vr <= r){
  149.             mult(t[v].p, val);
  150.             return;
  151.         }
  152.         if (t[v].l != -1) qMultGo(t[v].l, l, r, val);
  153.         if (t[v].r != -1) qMultGo(t[v].r, l, r, val);
  154.         if (t[v].val.x >= l && t[v].val.x <= r){
  155.             mult(t[v].val.y, val);
  156.         }
  157.         recalc(v);
  158.     }
  159.  
  160.     void qDeleteVal(int val){
  161.         auto now = splitVal(root, val);
  162.         root = now.y;
  163.     }
  164.  
  165.     void insertGo(int v, pair<int, int> val){
  166.         add(t[v].sum, val.y);
  167.         add(t[v].func, val.x * (ll)val.y % mod);
  168.         if (t[v].val.x == val.x){
  169.             add(t[v].val.y, val.y);
  170.             return;
  171.         }
  172.         if (t[v].val.x > val.x) insertGo(t[v].l, val);
  173.         else insertGo(t[v].r, val);
  174.     }
  175.  
  176.     int insertNew(int v, int uk){
  177.         if (v == -1) return uk;
  178.         if (t[v].y > t[uk].y){
  179.             auto now = splitVal(v, t[uk].val.x);
  180.             t[uk].l = now.x;
  181.             t[uk].r = now.y;
  182.             recalc(uk);
  183.             return uk;
  184.         } else {
  185.             if (t[v].val.x > t[uk].val.x){
  186.                 int d = insertNew(t[v].l, uk);
  187.                 t[v].l = d;
  188.             } else {
  189.                 int d = insertNew(t[v].r, uk);
  190.                 t[v].r = d;
  191.             }
  192.             recalc(v);
  193.             return v;
  194.         }
  195.     }
  196.  
  197.     void insert(pair<int, int> val){
  198.         int q = qGetNumVal(root, val.x);
  199.         if (q == -1){
  200.             t[sz].val = val;
  201.             t[sz].sum = val.y;
  202.             t[sz].func = t[sz].val.x * (ll)t[sz].val.y % mod;
  203.             t[sz].mi = val.x;
  204.             t[sz].ma = val.x;
  205.             //auto now = splitVal(root, val.x);
  206.             //root = merge(now.x, sz);
  207.             //root = merge(root, now.y);
  208.             root = insertNew(root, sz);
  209.             sz++;
  210.         } else {
  211.             insertGo(root, val);
  212.             /*auto now = splitVal(root, val.x);
  213.             auto now2 = splitVal(now.y, val.x + 1);
  214.             add(t[now2.x].val.y, val.y);
  215.             add(t[now2.x].sum, val.y);
  216.             t[now2.x].func = t[now2.x].val.x * (ll)t[now2.x].val.y % mod;
  217.             root = merge(now.x, now2.x);
  218.             root = merge(root, now2.y);*/
  219.         }
  220.     }
  221.  
  222.     void go(int v){
  223.         if (v == -1) return;
  224.         push(v);
  225.         go(t[v].l);
  226.         tmp[kk++] = t[v].val;
  227.         go(t[v].r);
  228.     }
  229.  
  230.     int qGetFunc(){
  231.         return getFunc(root);
  232.     }
  233. };
  234.  
  235. int n;
  236. int color[200007];
  237. vector<int> g[200007];
  238. int ans = 0;
  239. dd dp[200007];
  240.  
  241. int foo[200007];
  242.  
  243. void dfs(int v, int pred){
  244.     for (int to : g[v]) if (to != pred){
  245.         dfs(to, v);
  246.         if (dp[v].getSz(dp[v].root) < dp[to].getSz(dp[to].root)) swap(dp[v], dp[to]);
  247.         kk = 0; dp[to].go(dp[to].root);
  248.         for (int i = 0; i < kk; i++){
  249.             foo[i] = dp[v].qGetSumVal(dp[v].root, tmp[i].x) + 1;
  250.         }
  251.  
  252.         int sum = 1;
  253.         for (int i = 0; i < kk; i++) add(sum, tmp[i].y);
  254.         int last = 1e9 + 7;
  255.         for (int i = kk - 1; i >= 0; i--){
  256.             dp[v].qMultGo(dp[v].root, tmp[i].x, last - 1, sum);
  257.             last = tmp[i].x;
  258.             sum -= tmp[i].y;
  259.             if (sum < 0) sum += mod;
  260.         }
  261.  
  262.         for (int i = 0; i < kk; i++){
  263.             dp[v].insert(mp(tmp[i].x, tmp[i].y * (ll)foo[i] % mod));
  264.         }
  265.     }
  266.     int sum = dp[v].qGetSumVal(dp[v].root, color[v]) + 1;
  267.     if (sum >= mod) sum -= mod;
  268.     dp[v].qDeleteVal(color[v]);
  269.  
  270.     dp[v].insert(mp(color[v], sum));
  271.     add(ans, dp[v].qGetFunc());
  272.  
  273.     //kk = 0;
  274.     //dp[v].go(dp[v].root);
  275.     //for (int i = 0; i < kk; i++) add(ans, tmp[i].x * (ll)tmp[i].y % mod);
  276.  
  277. }
  278.  
  279. dd test;
  280.  
  281. const bool is_testing = 0;
  282. int main(){
  283.     srand(time(NULL));
  284.     if (is_testing){
  285.         freopen("input.txt", "r", stdin);
  286.         freopen("output.txt", "w", stdout);
  287.     }
  288.     scanf("%d", &n);
  289.     for (int i = 0; i < n; i++) scanf("%d", &color[i]);
  290.     for (int i = 0; i < n - 1; i++){
  291.         int a, b;
  292.         scanf("%d %d", &a, &b);
  293.         a--; b--;
  294.         g[a].pub(b);
  295.         g[b].pub(a);
  296.     }
  297.     dfs(0, -1);
  298.     cout << ans;
  299. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement