Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- #ifdef _DEBUG
- #define _GLIBCXX_DEBUG
- #endif
- #define _CRT_SECURE_NO_WARNINGS
- #include <bits/stdc++.h>
- using namespace std;
// Short aliases used throughout the solution.
typedef long long ll;
typedef std::vector<ll> vll;
typedef std::pair<int, int> pii;
typedef std::vector<int> vi;
typedef long double ld;
#define mk make_pair
#define inb push_back
#define enb emplace_back
#define X first
#define Y second
// Arguments and results are fully parenthesised so these macros compose
// safely inside larger expressions (e.g. 36 / sqr(3) must be 4, not 36).
#define all(v) (v).begin(), (v).end()
#define sqr(x) ((x) * (x))
// Elapsed CPU time in seconds since program start.
#define TIME (1.0 * clock() / CLOCKS_PER_SEC)
// Presumably renames `y1` so it can be used as a variable name without
// clashing with the POSIX Bessel function y1() from <cmath> — TODO confirm.
#define y1 AYDARBOG
//continue break pop_back return
int solve();

// Program entry point. All work happens in solve(), which reads stdin and
// writes stdout. Enabling the commented freopen calls switches to the
// TASK.in / TASK.out file I/O. Builds with _DEBUG defined also report the
// elapsed CPU time on stderr.
int main()
{
#define TASK "mail"
#ifndef _DEBUG
    // freopen(TASK".in", "r", stdin);
    // freopen(TASK".out", "w", stdout);
#endif
    solve();
#ifdef _DEBUG
    const double elapsed = TIME;
    fprintf(stderr, "\nTIME: %.3f\n", elapsed);
#endif
    return 0;
}
const int BUFSZ = 1000000 + 7;
char buf[BUFSZ];
// Reads the next whitespace-delimited token from stdin.
// At most BUFSZ - 1 characters are stored per call. A longer token is split
// across successive calls instead of overflowing `buf`. Returns an empty
// string once input is exhausted, rather than the stale token left in `buf`.
std::string get_str()
{
    // Build " %<BUFSZ-1>s" once, so the width always tracks BUFSZ.
    static char fmt[16] = "";
    if (fmt[0] == '\0')
        snprintf(fmt, sizeof fmt, " %%%ds", BUFSZ - 1);
    if (scanf(fmt, buf) != 1)
        return std::string();
    return std::string(buf);
}
// Answers colour-flip / sum-of-distances queries on a tree using centroid
// decomposition.
//
// Input (stdin): "n q", then n-1 edges "x y" (1-based), then q commands
// "cmd u" (1-based u). Every vertex starts with colour 0.
//   cmd == 1  : flip the colour (0 <-> 1) of vertex u.
//   otherwise : print the sum of tree distances from u to every vertex that
//               currently has the same colour as u (u itself counts, at 0).
// Output (stdout): one line per query. Returns 0.
//
// For every centroid c and colour k the function keeps:
//   sumcnt[k][c], sumdst[k][c] - count of colour-k vertices in c's component
//                                and the sum of their distances to c;
//   cntv[c][k][i], csd[c][k][i] - the same, restricted to the i-th subtree
//                                 hanging off c, so it can be subtracted out.
// Each vertex stores its centroid ancestors together with its exact distance
// to them, so no LCA structure is needed. All traversals are iterative, so
// deep (path-like) trees cannot overflow the call stack.
int solve()
{
    int n = 0, q = 0;
    if (scanf("%d %d", &n, &q) != 2 || n <= 0)
        return 0;
    std::vector<std::vector<int>> g(n);
    for (int i = 0; i < n - 1; ++i)
    {
        int a = 0, b = 0;
        if (scanf("%d %d", &a, &b) != 2)
            return 0;
        --a, --b;
        g[a].push_back(b);
        g[b].push_back(a);
    }

    // One entry per centroid ancestor of a vertex: the centroid, the index of
    // the centroid's subtree that contains the vertex, and their distance.
    struct Anc
    {
        int center;
        int id;
        int dist;
    };
    std::vector<std::vector<Anc>> pos(n);
    std::vector<std::vector<long long>> sumdst(2, std::vector<long long>(n));
    std::vector<std::vector<int>> sumcnt(2, std::vector<int>(n));
    std::vector<std::vector<std::vector<long long>>> csd(
        n, std::vector<std::vector<long long>>(2));
    std::vector<std::vector<std::vector<int>>> cntv(
        n, std::vector<std::vector<int>>(2));
    std::vector<int> col(n);

    // BUILD CD BEGIN
    std::vector<char> removed(n, 0);  // vertex already chosen as a centroid
    std::vector<int> par(n), sub(n), dist(n), order;
    order.reserve(n);
    std::vector<int> pending{0};      // roots of components still to split
    while (!pending.empty())
    {
        const int root = pending.back();
        pending.pop_back();

        // Collect the component in BFS order (parents before children).
        order.clear();
        order.push_back(root);
        par[root] = -1;
        for (std::size_t i = 0; i < order.size(); ++i)
        {
            const int v = order[i];
            for (int to : g[v])
                if (to != par[v] && !removed[to])
                {
                    par[to] = v;
                    order.push_back(to);
                }
        }
        // Subtree sizes: fold children into parents in reverse BFS order.
        for (int v : order)
            sub[v] = 1;
        for (std::size_t i = order.size(); i-- > 1;)
            sub[par[order[i]]] += sub[order[i]];

        // Descend into the child holding at least half of the component until
        // there is none; the vertex reached is a centroid.
        const int total = static_cast<int>(order.size());
        int c = root;
        for (bool moved = true; moved;)
        {
            moved = false;
            for (int to : g[c])
                if (to != par[c] && !removed[to] && 2 * sub[to] >= total)
                {
                    c = to;
                    moved = true;
                    break;
                }
        }

        removed[c] = 1;
        const std::size_t deg = g[c].size();
        for (int k = 0; k < 2; ++k)
        {
            csd[c][k].assign(deg, 0LL);
            cntv[c][k].assign(deg, 0);
        }
        sumcnt[0][c] = 1;  // the centroid itself: colour 0, distance 0
        int id = 0;
        for (int nb : g[c])
        {
            if (removed[nb])
                continue;
            // BFS over the subtree behind nb, recording distances to c.
            order.clear();
            order.push_back(nb);
            par[nb] = c;
            dist[nb] = 1;
            for (std::size_t i = 0; i < order.size(); ++i)
            {
                const int v = order[i];
                pos[v].push_back({c, id, dist[v]});
                csd[c][0][id] += dist[v];
                ++cntv[c][0][id];
                for (int to : g[v])
                    if (to != par[v] && !removed[to])
                    {
                        par[to] = v;
                        dist[to] = dist[v] + 1;
                        order.push_back(to);
                    }
            }
            sumdst[0][c] += csd[c][0][id];
            sumcnt[0][c] += cntv[c][0][id];
            pending.push_back(nb);
            ++id;
        }
    }
    // BUILD CD END

    while (q-- > 0)
    {
        int cmd = 0, u = 0;
        if (scanf("%d %d", &cmd, &u) != 2)
            break;
        --u;
        if (cmd == 1)
        {
            // Flip u's colour. Move its contribution between the colour
            // buckets of u itself (as a centroid, distance 0) and of every
            // centroid ancestor.
            const int oldc = col[u];
            const int newc = oldc ^ 1;
            col[u] = newc;
            --sumcnt[oldc][u];
            ++sumcnt[newc][u];
            for (const Anc& a : pos[u])
            {
                --sumcnt[oldc][a.center];
                ++sumcnt[newc][a.center];
                --cntv[a.center][oldc][a.id];
                ++cntv[a.center][newc][a.id];
                sumdst[oldc][a.center] -= a.dist;
                sumdst[newc][a.center] += a.dist;
                csd[a.center][oldc][a.id] -= a.dist;
                csd[a.center][newc][a.id] += a.dist;
            }
        }
        else
        {
            // Same-colour vertices in u's own component, plus, for each
            // centroid ancestor, those outside u's subtree of it (reached
            // through the centroid).
            const int cc = col[u];
            long long ans = sumdst[cc][u];
            for (const Anc& a : pos[u])
            {
                const long long cv =
                    sumcnt[cc][a.center] - cntv[a.center][cc][a.id];
                const long long cd =
                    sumdst[cc][a.center] - csd[a.center][cc][a.id];
                ans += cv * a.dist + cd;
            }
            printf("%lld\n", ans);
        }
    }
    return 0;
}
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement