Advertisement
Not a member of Pastebin yet?
Sign up — it unlocks many cool features!
- #include <bits/stdc++.h>
- #include <ext/pb_ds/assoc_container.hpp>
- #include <ext/pb_ds/tree_policy.hpp>
- using namespace __gnu_pbds;
- using namespace std;
// NOTE(review): `int` is redefined as `long long` for everything after this line,
// so every `int` below (arrays, parameters, colours, loop counters) is 64-bit.
#define int long long
#define pb push_back
#define mp make_pair
// GCC optimisation/target pragmas; GCC applies them to functions defined after this point.
// NOTE(review): avx/avx2/tune=native assume the judging CPU and GCC support them — confirm.
#pragma GCC optimize("Ofast")
#pragma GCC optimize("unroll-loops")
#pragma GCC target("sse,sse2,sse3,ssse3,sse4,popcnt,abm,mmx,tune=native")
#pragma GCC target("avx,avx2")
constexpr int INF = (int)1e9;          // not referenced in the visible code
constexpr int N = (int)2e5 + 11;       // capacity of per-vertex arrays (vertices are 1..n, so n < N)
constexpr int MAXN = (int)3000 + 11;   // not referenced in the visible code
constexpr int md = (int)1e9+7;         // modulus, used only by add()
- void add(int& a,int b){
- a += b;
- if(a >= md) a -= md;
- }
- vector<int> g[N];
- int sz[N];
- int parent[N];
- int depth[N];
- int heavy[N];
- int up[20][N];
- int tin[N],tout[N];
- int timer = 0;
- int col[N];
- void dfs(int v,int pr = 0){
- sz[v] = 1;
- parent[v] = pr;
- tin[v] = timer++;
- up[0][v] = pr;
- for(int i = 1; i < 20; i++){
- up[i][v] = up[i-1][up[i-1][v]];
- }
- for(auto& x : g[v]){
- if(x == pr)
- continue;
- depth[x] = depth[v] + 1;
- dfs(x,v);
- if(heavy[v] == 0 || sz[x] > sz[heavy[v]])
- heavy[v] = x;
- sz[v] += sz[x];
- }
- tout[v] = timer++;
- return;
- }
- bool upper(int a,int b){
- return tin[a] <= tin[b] && tout[a] >= tout[b];
- }
- int lca(int a,int b){
- if(upper(a,b))
- return a;
- if(upper(b,a))
- return b;
- for(int i = 19; i >= 0; i--){
- if(up[i][a] != 0 && !upper(up[i][a],b))
- a = up[i][a];
- }
- return up[0][a];
- }
- vector<int> a;
- int pos[N];
- int head[N];
- void hld(int v,int h){
- pos[v] = a.size();
- a.push_back(v);
- head[v] = h;
- if(heavy[v] != 0)
- hld(heavy[v],h);
- for(auto& x : g[v]){
- if(x == parent[v] || x == heavy[v])
- continue;
- hld(x,x);
- }
- return;
- }
- struct node{
- int ans = 0;
- int col_r = 0,col_l = 0;
- node(){}
- node(int x){
- ans = 0;
- col_l = col_r = x;
- }
- } t[4*N];
- int prom[4*N];
- node combine(node a,node b){
- if(a.col_l == 0)
- return b;
- if(b.col_l == 0)
- return a;
- node c;
- c.ans = a.ans + b.ans + (a.col_r != b.col_l);
- c.col_l = a.col_l;
- c.col_r = b.col_r;
- return c;
- }
- void push(int v){
- if(prom[v] == 0)
- return;
- t[v<<1] = node(prom[v]);
- t[v<<1|1] = node(prom[v]);
- prom[v<<1] = prom[v<<1|1] = prom[v];
- prom[v] = 0;
- return;
- }
- void build(int v,int l,int r){
- if(l == r){
- t[v] = node(col[a[l]]);
- return;
- }
- int m = (l+r)>>1;
- build(v<<1,l,m);
- build(v<<1|1,m+1,r);
- t[v] = combine(t[v<<1],t[v<<1|1]);
- }
- void upd(int v,int l,int r,int tl,int tr,int col){
- if(l > r || tl > tr)
- return;
- // cerr << "upd = " << v << " " << l+1 << " " << r+1 << " " << tl+1 << " " << tr+1 << " " << col << "\n";
- if(l == tl && r == tr){
- t[v] = node(col);
- prom[v] = col;
- return;
- }
- push(v);
- int m = (l+r)>>1;
- upd(v<<1,l,m,tl,min(tr,m),col);
- upd(v<<1|1,m+1,r,max(tl,m+1),tr,col);
- t[v] = combine(t[v<<1],t[v<<1|1]);
- return;
- }
- node get(int v,int l,int r,int tl,int tr){
- if(l > r || tl > tr)
- return node();
- // cerr << l+1 << " " << r+1 << ", ans = " << t[v].ans << "\n";
- // cerr << "colors = " << t[v].col_l << " " << t[v].col_r << "\n";
- if(l == tl && r == tr)
- return t[v];
- push(v);
- int m = (l+r)>>1;
- return combine(get(v<<1,l,m,tl,min(tr,m)),get(v<<1|1,m+1,r,max(tl,m+1),tr));
- }
- int n;
- node query(int a,int b){ /// a is upper than b
- if(upper(b,a))
- swap(a,b);
- node cur = node();
- for(; head[a] != head[b]; b = parent[head[b]]){
- cur = combine(get(1,0,n-1,pos[head[b]],pos[b]),cur);
- }
- cur = combine(get(1,0,n-1,pos[a],pos[b]),cur);
- return cur;
- }
- int get(int a,int b){
- int L = lca(a,b);
- node A = query(L,a);
- node B = query(L,b);
- // cerr << "get " << L << " " << a << " = " << A.ans << "\n";
- // cerr << "get " << L << " " << b << " = " << B.ans << "\n\n";
- // swap(B.col_l,B.col_r);
- return A.ans + B.ans;
- }
- void upd(int a,int b,int c){ /// a is upper than b
- if(upper(b,a))
- swap(a,b);
- for(; head[a] != head[b]; b = parent[head[b]]){
- // if (depth[head[a]] > depth[head[b]])
- // swap(a, b);
- upd(1,0,n-1,pos[head[b]],pos[b],c);
- }
- upd(1,0,n-1,pos[a],pos[b],c);
- return;
- }
- void updPath(int a,int b,int c){
- int l = lca(a,b);
- upd(l,a,c);
- upd(l,b,c);
- return;
- }
- void solve(){
- cin >> n;
- for(int i = 1; i < n; i++){
- int a,b;
- cin >> a >> b;
- g[a].pb(b);
- g[b].pb(a);
- }
- dfs(1);
- hld(1,1);
- for(int i = 1; i <= n; i++){
- cin >> col[i];
- }
- build(1,0,n-1);
- // cerr << get(1,1) << "\n";
- // return;
- // for(auto& x : a){
- // cerr << x << " ";
- // }
- // cerr << "\n";
- // updPath(6,4,2);
- // cerr << get(2,3) << "\n";
- // return;
- int q;
- cin >> q;
- while(q--){
- int tp;
- cin >> tp;
- if(tp == 1){
- int u,v,col;
- cin >> u >> v >> col;
- updPath(u,v,col);
- }
- if(tp == 2){
- int u,c;
- cin >> u >> c;
- upd(1,0,n-1,pos[u],pos[u]+sz[u]-1,c);
- }
- if(tp == 3){
- int u,v;
- cin >> u >> v;
- cout << get(u,v) << "\n";
- }
- }
- return;
- }
- signed main(){
- // ifstream cin("colors.01.in");
- // ofstream cout("output.txt");
- ios::sync_with_stdio(false);cin.tie(nullptr);cout.tie(nullptr);
- // freopen("colors.01.in","r",stdin);
- // freopen("output.txt","w",stdout);
- int tests = 1;
- // cin >> tests;
- while(tests--){
- solve();
- }
- }
- /**
- 1 2 3 4 5 6
- **/
Advertisement
Add Comment
Please sign in to add a comment.
Advertisement