Advertisement
Not a member of Pastebin yet?
Sign up: it unlocks many cool features!
- #include <iostream>
- #include <vector>
- #include <set>
- #include <tuple>
- #include <fstream>
- using namespace std;
// Array bound for per-node tables (arrays are indexed by node id).
const int maxN = 100001;
// Every answer and every Fenwick value is kept modulo this prime.
const int mod = 1000000007;

/// Fenwick (binary indexed) tree over positions [0, n) with values kept
/// modulo `mod`. `update` ASSIGNS a value to a position rather than adding
/// to it: the last assigned raw value is remembered in `actual`, and only
/// the difference is pushed into the tree.
struct fenwick_tree
{
    int n;
    std::vector<long long int> bit;     // tree nodes, each reduced into [0, mod)
    std::vector<long long int> actual;  // last raw value assigned to each position

    /// Creates a tree of `n` positions, all zero.
    explicit fenwick_tree(int n) : n(n), bit(n, 0LL), actual(n, 0LL) {}

    /// Sets position `pos` to `replacement` (any sign; reduced mod `mod`).
    /// `pos` must lie in [0, n).
    void update(int pos, long long int replacement)
    {
        // Difference from the previous value, normalised into [0, mod).
        const long long delta = ((replacement - actual[pos]) % mod + mod) % mod;
        for (int at = pos; at < n; at |= at + 1)
            bit[at] = (bit[at] + delta) % mod;
        actual[pos] = replacement;
    }

    /// Returns the prefix sum over positions [0, r], in [0, mod).
    /// A negative `r` gives 0; `r >= n` is clamped to the whole array.
    int sum(int r) const
    {
        // Accumulate in 64 bits so correctness does not hinge on 2*mod < INT_MAX.
        long long res = 0;
        for (int at = (r < n ? r : n - 1); at >= 0; at = (at & (at + 1)) - 1)
            res = (res + bit[at]) % mod;
        return static_cast<int>(res);
    }
};
- vector<pair<int, pair<int, int>>> adj[maxN];
- int tour[2*maxN], node_toll[maxN], total_dist[maxN], depth[maxN], par[maxN];
- fenwick_tree total_toll(2*maxN), to_node(2*maxN), from_node(2*maxN);
- pair<int, int> in_out[maxN];
- int st[2*maxN][20], log[2*maxN];
- int tour_counter = 0;
- void update_to_node(int root, int new_toll)
- {
- long long int replacement = (long long) total_dist[par[root]] * new_toll;
- replacement %= mod;
- to_node.update(in_out[root].first, replacement);
- to_node.update(in_out[root].second, -replacement);
- }
- void update_from_node(int root, int new_toll)
- {
- long long int replacement = (long long) total_dist[root]*new_toll;
- replacement %= mod;
- from_node.update(in_out[root].first, replacement);
- from_node.update(in_out[root].second, -replacement);
- }
- void update_total_toll(int root, int new_toll)
- {
- total_toll.update(in_out[root].first, new_toll);
- total_toll.update(in_out[root].second, -new_toll);
- node_toll[root] = new_toll;
- }
- void update_all(int root, int toll)
- {
- update_to_node(root, toll);
- update_from_node(root, toll);
- update_total_toll(root, toll);
- }
int toll(pair<int, pair<int, int>> a)
{
    // Edges are stored as (neighbour, (toll, distance)); the toll is the
    // first field of the inner pair.
    const std::pair<int, int> &weights = a.second;
    return weights.first;
}
int dist(pair<int, pair<int, int> > a)
{
    // Edges are stored as (neighbour, (toll, distance)); the distance is the
    // second field of the inner pair.
    const std::pair<int, int> &weights = a.second;
    return weights.second;
}
int child(pair<int, pair<int, int> > a)
{
    // The outer pair's first field is the node at the other end of the edge.
    const int neighbour = a.first;
    return neighbour;
}
- void calculate_city_info(int root, int tolls)
- {
- cout << "first" << endl;
- tour[tour_counter] = root;
- in_out[root].first = tour_counter;
- tour_counter++;
- for (pair<int, pair<int, int> > edge: adj[root])
- {
- if (child(edge) != par[root])
- {
- par[child(edge)] = root;
- depth[child(edge)] = depth[root] + 1;
- total_dist[child(edge)] = total_dist[root] + dist(edge); total_dist[child(edge)] %= mod;
- cout << "Second" << endl;
- calculate_city_info(child(edge), toll(edge));
- tour[tour_counter] = root;
- tour_counter++;
- }
- }
- in_out[root].second = tour_counter;
- update_all(root, tolls);
- }
- int higher(int a, int b)
- {
- return depth[a] < depth[b] ? a: b;
- }
- int find_lca(int first, int second)
- {
- int length = in_out[second].first - in_out[first].first + 1;
- int one = st[in_out[first].first][log[length]];
- int next_start = in_out[second].first - (1 << log[length]) + 1;
- int two = st[next_start][log[length]];
- return higher(one, two);
- }
- int to_root(int vertex)
- {
- int minus = from_node.sum(in_out[vertex].first);
- minus = (minus+mod)%mod;
- int total = ((long long) total_dist[vertex] * total_toll.sum(in_out[vertex].first))%mod;
- total = (total+mod)%mod;
- return (total - minus + mod)%mod;
- }
- int from_root(int vertex)
- {
- return (to_node.sum(in_out[vertex].first)+mod)%mod;
- }
- int find_to_lca(int start, int end, int lca, int to_end_toll)
- {
- int dist_to_lca = (total_dist[start] - total_dist[lca] + mod)%mod;
- int to_lca = to_root(start) - to_root(lca);
- to_lca = (to_lca + mod) %mod;
- to_lca -= ((long long)dist_to_lca * total_toll.sum(in_out[lca].first))%mod;
- to_lca = (to_lca+mod)%mod;
- to_lca += ((long long) dist_to_lca * to_end_toll)%mod;
- to_lca = (to_lca + mod)%mod;
- return to_lca;
- }
- int find_from_lca(int end, int lca, int to_end_toll)
- {
- int dist_from_lca = (total_dist[end] - total_dist[lca] + mod)%mod;
- int from_lca = from_root(end) - from_root(lca);
- from_lca = (from_lca + mod) % mod;
- from_lca -= ((long long) to_end_toll * total_dist[lca])%mod;
- from_lca = (from_lca+mod)%mod;
- return from_lca;
- }
- long long int solve(int start, int end, int g)
- {
- int first, second;
- tie(first, second) = in_out[start].first > in_out[end].first ? make_tuple(end, start) : make_tuple(start, end);
- int lca = find_lca(first, second);
- int to_end_total_toll = total_toll.sum(in_out[end].first) - total_toll.sum(in_out[lca].first);
- to_end_total_toll = (to_end_total_toll+mod)%mod;
- int to_lca = find_to_lca(start, end, lca, to_end_total_toll);
- int from_lca = find_from_lca(end, lca, to_end_total_toll);
- int total_distance = total_dist[start] + total_dist[end] - ((2 * total_dist[lca])%mod);
- total_distance = (total_distance + mod)%mod;
- return (((long long) total_distance * g)%mod + to_lca + from_lca) % mod;
- }
- int main()
- {
- for (int i = 2; i < 2*maxN; i++) log[i] = log[i/2] +1;
- int n, g;
- //fin >> n >> g;
- cin >> n >> g;
- for (int i = 0; i < n-1; i++)
- {
- int a, b, d, t;
- //cin >> a >> b >> d >> t;
- a = i+1; b = i+2; d = 1000000000; t = 0;
- pair<int, pair<int, int> > some = make_pair(a, make_pair(t, d));
- adj[b].push_back(some);
- some.first = b;
- adj[a].push_back(some);
- }
- calculate_city_info(1, 0);
- for (int i = 0; i < 2*n-1; i++)
- {
- st[i][0] = tour[i];
- }
- for (int len = 1, count = 1; count > 0; len++)
- {
- count = 0;
- for (int start = 0; start + (1<<len) < 2*n-1; start++, count++)
- {
- st[start][len] = higher(st[start][len-1], st[start+(1<<(len-1))][len-1]);
- }
- }
- int q; cin >> q;
- while(q--)
- {
- int type; cin >>type;
- if (type)
- {
- int start, end;
- cin >> start >> end;
- long long int ans = solve(start, end ,g);
- cout << ans << endl;
- }
- else
- {
- int x, y, t;
- cin >> x >> y >> t;
- if (depth[x] < depth[y])
- {
- update_all(y, t);
- }
- else
- {
- update_all(x, t);
- }
- }
- }
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement