Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- #include <bits/stdc++.h>
// Contest shorthands used throughout the file.
// WARNING: `x` and `y` are object-like macros, so every later `x` / `y`
// token is rewritten to `first` / `second` -- this includes every `.x` / `.y`
// pair access, the field `y` declared in struct vert, and any local or
// parameter spelled `x` or `y`.
#define pub push_back
#define mp make_pair
#define x first
#define y second
#define all(a) a.begin(), a.end()
typedef long long ll;
using namespace std;

// Modulus for all counting arithmetic: 10^9 + 7.
const int mod = 1000000007;
- void add(int &a, int b){
- a += b;
- if (a >= mod) a -= mod;
- }
- void mult(int &a, int b){
- a = (ll)a * b % mod;
- }
- struct vert{
- int y, l, r, sz, sum, p, func, mi, ma;
- pair<int, int> val;
- vert() { l = -1, r = -1, y = rand() * (ll)rand() % mod, sz = 1, sum = 0, p = 1, func = 0; }
- };
- vert t[30 * 200007];
- pair<int, int> tmp[200007];
- int kk = 0;
- int sz = 0;
- struct dd{
- int root;
- dd() { root = -1; }
- void push(int v){
- if (v == -1 || t[v].p == 1) return;
- mult(t[v].sum, t[v].p);
- mult(t[v].val.y, t[v].p);
- mult(t[v].func, t[v].p);
- if (t[v].l != -1) mult(t[t[v].l].p, t[v].p);
- if (t[v].r != -1) mult(t[t[v].r].p, t[v].p);
- t[v].p = 1;
- }
- int getSz(int v){
- if (v == -1) return 0;
- return t[v].sz;
- }
- int getSum(int v){
- if (v == -1) return 0;
- return t[v].sum;
- }
- int getFunc(int v){
- if (v == -1) return 0;
- return t[v].func;
- }
- int getMi(int v){
- if (v == -1) return 1e9 + 7;
- return t[v].mi;
- }
- int getMa(int v){
- if (v == -1) return -1;
- return t[v].ma;
- }
- void recalc(int v){
- if (v == -1) return;
- push(t[v].l);
- push(t[v].r);
- t[v].sz = getSz(t[v].l) + getSz(t[v].r) + 1;
- t[v].mi = min(t[v].val.x, min(getMi(t[v].l), getMi(t[v].r)));
- t[v].ma = max(t[v].val.x, max(getMa(t[v].l), getMa(t[v].r)));
- t[v].func = getFunc(t[v].l) + getFunc(t[v].r);
- if (t[v].func >= mod) t[v].func -= mod;
- t[v].func += t[v].val.x * (ll)t[v].val.y % mod;
- if (t[v].func >= mod) t[v].func -= mod;
- t[v].sum = getSum(t[v].l) + getSum(t[v].r);
- if (t[v].sum >= mod) t[v].sum -= mod;
- t[v].sum += t[v].val.y;
- if (t[v].sum >= mod) t[v].sum -= mod;
- }
- int merge(int a, int b){
- push(a), push(b);
- if (a == -1) return b;
- if (b == -1) return a;
- if (t[a].y < t[b].y){
- t[a].r = merge(t[a].r, b);
- recalc(a);
- return a;
- } else {
- t[b].l = merge(a, t[b].l);
- recalc(b);
- return b;
- }
- }
- pair<int, int> splitVal(int v, int val){
- if (v == -1) return mp(-1, -1);
- push(v);
- if (t[v].val.x >= val){
- auto now = splitVal(t[v].l, val);
- t[v].l = now.y;
- recalc(v);
- return mp(now.x, v);
- } else {
- auto now = splitVal(t[v].r, val);
- t[v].r = now.x;
- recalc(v);
- return mp(v, now.y);
- }
- }
- int qGetNumVal(int v, int val){
- if (v == -1) return -1;
- push(v);
- if (t[v].val.x == val) return v;
- if (t[v].val.x > val) return qGetNumVal(t[v].l, val);
- return qGetNumVal(t[v].r, val);
- }
- int qGetSumVal(int v, int x){
- if (v == -1) return 0;
- push(v);
- if (t[v].val.x < x){
- int ans = t[v].val.y;
- push(t[v].l);
- add(ans, getSum(t[v].l));
- add(ans, qGetSumVal(t[v].r, x));
- return ans;
- } else {
- return qGetSumVal(t[v].l, x);
- }
- }
- void qMultGo(int v, int l, int r, int val){
- if (v == -1) return;
- push(v);
- int vl = t[v].mi, vr = t[v].ma;
- if (r < vl || l > vr) return;
- if (vl >= l && vr <= r){
- mult(t[v].p, val);
- return;
- }
- if (t[v].l != -1) qMultGo(t[v].l, l, r, val);
- if (t[v].r != -1) qMultGo(t[v].r, l, r, val);
- if (t[v].val.x >= l && t[v].val.x <= r){
- mult(t[v].val.y, val);
- }
- recalc(v);
- }
- void qDeleteVal(int val){
- auto now = splitVal(root, val);
- root = now.y;
- }
- void insertGo(int v, pair<int, int> val){
- add(t[v].sum, val.y);
- add(t[v].func, val.x * (ll)val.y % mod);
- if (t[v].val.x == val.x){
- add(t[v].val.y, val.y);
- return;
- }
- if (t[v].val.x > val.x) insertGo(t[v].l, val);
- else insertGo(t[v].r, val);
- }
- int insertNew(int v, int uk){
- if (v == -1) return uk;
- if (t[v].y > t[uk].y){
- auto now = splitVal(v, t[uk].val.x);
- t[uk].l = now.x;
- t[uk].r = now.y;
- recalc(uk);
- return uk;
- } else {
- if (t[v].val.x > t[uk].val.x){
- int d = insertNew(t[v].l, uk);
- t[v].l = d;
- } else {
- int d = insertNew(t[v].r, uk);
- t[v].r = d;
- }
- recalc(v);
- return v;
- }
- }
- void insert(pair<int, int> val){
- int q = qGetNumVal(root, val.x);
- if (q == -1){
- t[sz].val = val;
- t[sz].sum = val.y;
- t[sz].func = t[sz].val.x * (ll)t[sz].val.y % mod;
- t[sz].mi = val.x;
- t[sz].ma = val.x;
- //auto now = splitVal(root, val.x);
- //root = merge(now.x, sz);
- //root = merge(root, now.y);
- root = insertNew(root, sz);
- sz++;
- } else {
- insertGo(root, val);
- /*auto now = splitVal(root, val.x);
- auto now2 = splitVal(now.y, val.x + 1);
- add(t[now2.x].val.y, val.y);
- add(t[now2.x].sum, val.y);
- t[now2.x].func = t[now2.x].val.x * (ll)t[now2.x].val.y % mod;
- root = merge(now.x, now2.x);
- root = merge(root, now2.y);*/
- }
- }
- void go(int v){
- if (v == -1) return;
- push(v);
- go(t[v].l);
- tmp[kk++] = t[v].val;
- go(t[v].r);
- }
- int qGetFunc(){
- return getFunc(root);
- }
- };
- int n;
- int color[200007];
- vector<int> g[200007];
- int ans = 0;
- dd dp[200007];
- int foo[200007];
- void dfs(int v, int pred){
- for (int to : g[v]) if (to != pred){
- dfs(to, v);
- if (dp[v].getSz(dp[v].root) < dp[to].getSz(dp[to].root)) swap(dp[v], dp[to]);
- kk = 0; dp[to].go(dp[to].root);
- for (int i = 0; i < kk; i++){
- foo[i] = dp[v].qGetSumVal(dp[v].root, tmp[i].x) + 1;
- }
- int sum = 1;
- for (int i = 0; i < kk; i++) add(sum, tmp[i].y);
- int last = 1e9 + 7;
- for (int i = kk - 1; i >= 0; i--){
- dp[v].qMultGo(dp[v].root, tmp[i].x, last - 1, sum);
- last = tmp[i].x;
- sum -= tmp[i].y;
- if (sum < 0) sum += mod;
- }
- for (int i = 0; i < kk; i++){
- dp[v].insert(mp(tmp[i].x, tmp[i].y * (ll)foo[i] % mod));
- }
- }
- int sum = dp[v].qGetSumVal(dp[v].root, color[v]) + 1;
- if (sum >= mod) sum -= mod;
- dp[v].qDeleteVal(color[v]);
- dp[v].insert(mp(color[v], sum));
- add(ans, dp[v].qGetFunc());
- //kk = 0;
- //dp[v].go(dp[v].root);
- //for (int i = 0; i < kk; i++) add(ans, tmp[i].x * (ll)tmp[i].y % mod);
- }
- dd test;
- const bool is_testing = 0;
- int main(){
- srand(time(NULL));
- if (is_testing){
- freopen("input.txt", "r", stdin);
- freopen("output.txt", "w", stdout);
- }
- scanf("%d", &n);
- for (int i = 0; i < n; i++) scanf("%d", &color[i]);
- for (int i = 0; i < n - 1; i++){
- int a, b;
- scanf("%d %d", &a, &b);
- a--; b--;
- g[a].pub(b);
- g[b].pub(a);
- }
- dfs(0, -1);
- cout << ans;
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement