Advertisement
Guest User

Untitled

a guest
Aug 5th, 2022
184
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 5.85 KB | None | 0 0
  1. #include <iostream>
  2. #include <vector>
  3. #include <set>
  4. #include <tuple>
  5. #include <fstream>
  6. using namespace std;
  7.  
  8. const int maxN = 1e5 + 1;
  9. const int mod = 1e9 + 7;
  10.  
  11. struct fenwick_tree
  12. {
  13. int n;
  14. vector<long long int> bit, actual;
  15.  
  16.  
  17. fenwick_tree(int n)
  18. {
  19. this -> n = n;
  20. bit.assign(n, 0);
  21. actual.assign(n, 0);
  22.  
  23. }
  24. void update(int pos, long long int replacement)
  25. {
  26. int delta = ((replacement - actual[pos])%mod + mod)%mod;
  27. for(int at = pos; at < n; at |= at + 1)
  28. {
  29. bit[at] += delta;
  30. bit[at]%= mod;
  31. }
  32. actual[pos] = replacement;
  33. }
  34. int sum(int r)
  35. {
  36. int res = 0;
  37. for (int at = r; at >= 0; at = (at&(at+1))-1)
  38. {
  39. res += bit[at];
  40. res %= mod;
  41. }
  42. return res;
  43. }
  44. };
  45.  
  46. vector<pair<int, pair<int, int>>> adj[maxN];
  47. int tour[2*maxN], node_toll[maxN], total_dist[maxN], depth[maxN], par[maxN];
  48. fenwick_tree total_toll(2*maxN), to_node(2*maxN), from_node(2*maxN);
  49. pair<int, int> in_out[maxN];
  50. int st[2*maxN][20], log[2*maxN];
  51.  
  52. int tour_counter = 0;
  53.  
  54. void update_to_node(int root, int new_toll)
  55. {
  56. long long int replacement = (long long) total_dist[par[root]] * new_toll;
  57. replacement %= mod;
  58. to_node.update(in_out[root].first, replacement);
  59. to_node.update(in_out[root].second, -replacement);
  60. }
  61. void update_from_node(int root, int new_toll)
  62. {
  63. long long int replacement = (long long) total_dist[root]*new_toll;
  64. replacement %= mod;
  65. from_node.update(in_out[root].first, replacement);
  66. from_node.update(in_out[root].second, -replacement);
  67. }
  68. void update_total_toll(int root, int new_toll)
  69. {
  70. total_toll.update(in_out[root].first, new_toll);
  71. total_toll.update(in_out[root].second, -new_toll);
  72. node_toll[root] = new_toll;
  73. }
  74.  
  75. void update_all(int root, int toll)
  76. {
  77. update_to_node(root, toll);
  78. update_from_node(root, toll);
  79. update_total_toll(root, toll);
  80. }
  81.  
  82. int toll(pair<int, pair<int, int>> a)
  83. {
  84. return a.second.first;
  85. }
  86.  
  87. int dist(pair<int, pair<int, int> > a)
  88. {
  89. return a.second.second;
  90. }
  91. int child(pair<int, pair<int, int> > a)
  92. {
  93. return a.first;
  94. }
  95.  
  96. void calculate_city_info(int root, int tolls)
  97. {
  98. cout << "first" << endl;
  99. tour[tour_counter] = root;
  100. in_out[root].first = tour_counter;
  101. tour_counter++;
  102.  
  103.  
  104.  
  105. for (pair<int, pair<int, int> > edge: adj[root])
  106. {
  107. if (child(edge) != par[root])
  108. {
  109. par[child(edge)] = root;
  110. depth[child(edge)] = depth[root] + 1;
  111.  
  112. total_dist[child(edge)] = total_dist[root] + dist(edge); total_dist[child(edge)] %= mod;
  113. cout << "Second" << endl;
  114. calculate_city_info(child(edge), toll(edge));
  115. tour[tour_counter] = root;
  116. tour_counter++;
  117. }
  118. }
  119. in_out[root].second = tour_counter;
  120. update_all(root, tolls);
  121. }
  122.  
  123. int higher(int a, int b)
  124. {
  125. return depth[a] < depth[b] ? a: b;
  126. }
  127.  
  128. int find_lca(int first, int second)
  129. {
  130. int length = in_out[second].first - in_out[first].first + 1;
  131. int one = st[in_out[first].first][log[length]];
  132. int next_start = in_out[second].first - (1 << log[length]) + 1;
  133. int two = st[next_start][log[length]];
  134. return higher(one, two);
  135. }
  136.  
  137. int to_root(int vertex)
  138. {
  139. int minus = from_node.sum(in_out[vertex].first);
  140. minus = (minus+mod)%mod;
  141. int total = ((long long) total_dist[vertex] * total_toll.sum(in_out[vertex].first))%mod;
  142. total = (total+mod)%mod;
  143. return (total - minus + mod)%mod;
  144. }
  145. int from_root(int vertex)
  146. {
  147. return (to_node.sum(in_out[vertex].first)+mod)%mod;
  148. }
  149.  
  150. int find_to_lca(int start, int end, int lca, int to_end_toll)
  151. {
  152. int dist_to_lca = (total_dist[start] - total_dist[lca] + mod)%mod;
  153. int to_lca = to_root(start) - to_root(lca);
  154. to_lca = (to_lca + mod) %mod;
  155. to_lca -= ((long long)dist_to_lca * total_toll.sum(in_out[lca].first))%mod;
  156. to_lca = (to_lca+mod)%mod;
  157. to_lca += ((long long) dist_to_lca * to_end_toll)%mod;
  158. to_lca = (to_lca + mod)%mod;
  159. return to_lca;
  160. }
  161.  
  162. int find_from_lca(int end, int lca, int to_end_toll)
  163. {
  164. int dist_from_lca = (total_dist[end] - total_dist[lca] + mod)%mod;
  165. int from_lca = from_root(end) - from_root(lca);
  166. from_lca = (from_lca + mod) % mod;
  167. from_lca -= ((long long) to_end_toll * total_dist[lca])%mod;
  168. from_lca = (from_lca+mod)%mod;
  169. return from_lca;
  170. }
  171.  
  172. long long int solve(int start, int end, int g)
  173. {
  174. int first, second;
  175. tie(first, second) = in_out[start].first > in_out[end].first ? make_tuple(end, start) : make_tuple(start, end);
  176. int lca = find_lca(first, second);
  177. int to_end_total_toll = total_toll.sum(in_out[end].first) - total_toll.sum(in_out[lca].first);
  178. to_end_total_toll = (to_end_total_toll+mod)%mod;
  179. int to_lca = find_to_lca(start, end, lca, to_end_total_toll);
  180. int from_lca = find_from_lca(end, lca, to_end_total_toll);
  181.  
  182. int total_distance = total_dist[start] + total_dist[end] - ((2 * total_dist[lca])%mod);
  183. total_distance = (total_distance + mod)%mod;
  184.  
  185. return (((long long) total_distance * g)%mod + to_lca + from_lca) % mod;
  186. }
  187.  
  188. int main()
  189. {
  190. for (int i = 2; i < 2*maxN; i++) log[i] = log[i/2] +1;
  191. int n, g;
  192. //fin >> n >> g;
  193. cin >> n >> g;
  194.  
  195. for (int i = 0; i < n-1; i++)
  196. {
  197. int a, b, d, t;
  198. //cin >> a >> b >> d >> t;
  199. a = i+1; b = i+2; d = 1000000000; t = 0;
  200. pair<int, pair<int, int> > some = make_pair(a, make_pair(t, d));
  201. adj[b].push_back(some);
  202. some.first = b;
  203. adj[a].push_back(some);
  204. }
  205. calculate_city_info(1, 0);
  206. for (int i = 0; i < 2*n-1; i++)
  207. {
  208. st[i][0] = tour[i];
  209. }
  210. for (int len = 1, count = 1; count > 0; len++)
  211. {
  212. count = 0;
  213. for (int start = 0; start + (1<<len) < 2*n-1; start++, count++)
  214. {
  215. st[start][len] = higher(st[start][len-1], st[start+(1<<(len-1))][len-1]);
  216. }
  217. }
  218. int q; cin >> q;
  219. while(q--)
  220. {
  221. int type; cin >>type;
  222. if (type)
  223. {
  224. int start, end;
  225. cin >> start >> end;
  226. long long int ans = solve(start, end ,g);
  227. cout << ans << endl;
  228. }
  229. else
  230. {
  231. int x, y, t;
  232. cin >> x >> y >> t;
  233. if (depth[x] < depth[y])
  234. {
  235. update_all(y, t);
  236. }
  237. else
  238. {
  239. update_all(x, t);
  240. }
  241. }
  242. }
  243. }
  244.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement