Advertisement
Guest User

MyCode

a guest
Aug 5th, 2022
243
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 5.80 KB | None | 0 0
#include <cstddef>
#include <fstream>
#include <iostream>
#include <set>
#include <tuple>
#include <utility>
#include <vector>
using namespace std;
  7.  
  8. const int maxN = 1e5 + 1;
  9. const int mod = 1e9 + 7;
  10.  
  11. struct fenwick_tree
  12. {
  13. int n;
  14. vector<long long int> bit, actual;
  15.  
  16.  
  17. fenwick_tree(int n)
  18. {
  19. this -> n = n;
  20. bit.assign(n, 0);
  21. actual.assign(n, 0);
  22.  
  23. }
  24. void update(int pos, long long int replacement)
  25. {
  26. int delta = ((replacement - actual[pos])%mod + mod)%mod;
  27. for(int at = pos; at < n; at |= at + 1)
  28. {
  29. bit[at] += delta;
  30. bit[at]%= mod;
  31. }
  32. actual[pos] = replacement;
  33. }
  34. int sum(int r)
  35. {
  36. int res = 0;
  37. for (int at = r; at >= 0; at = (at&(at+1))-1)
  38. {
  39. res += bit[at];
  40. res %= mod;
  41. }
  42. return res;
  43. }
  44. };
  45.  
  46. vector<pair<int, pair<int, int>>> adj[maxN];
  47. int tour[2*maxN], node_toll[maxN], total_dist[maxN], depth[maxN], par[maxN];
  48. fenwick_tree total_toll(2*maxN), to_node(2*maxN), from_node(2*maxN);
  49. pair<int, int> in_out[maxN];
  50. int st[2*maxN][20], log[2*maxN];
  51.  
  52. int tour_counter = 0;
  53.  
  54. void update_to_node(int root, int new_toll)
  55. {
  56. long long int replacement = (long long) total_dist[par[root]] * new_toll;
  57. replacement %= mod;
  58. to_node.update(in_out[root].first, replacement);
  59. to_node.update(in_out[root].second, -replacement);
  60. }
  61. void update_from_node(int root, int new_toll)
  62. {
  63. long long int replacement = (long long) total_dist[root]*new_toll;
  64. replacement %= mod;
  65. from_node.update(in_out[root].first, replacement);
  66. from_node.update(in_out[root].second, -replacement);
  67. }
  68. void update_total_toll(int root, int new_toll)
  69. {
  70. total_toll.update(in_out[root].first, new_toll);
  71. total_toll.update(in_out[root].second, -new_toll);
  72. node_toll[root] = new_toll;
  73. }
  74.  
  75. void update_all(int root, int toll)
  76. {
  77. update_to_node(root, toll);
  78. update_from_node(root, toll);
  79. update_total_toll(root, toll);
  80. }
  81.  
  82. int toll(pair<int, pair<int, int>> a)
  83. {
  84. return a.second.first;
  85. }
  86.  
  87. int dist(pair<int, pair<int, int> > a)
  88. {
  89. return a.second.second;
  90. }
  91. int child(pair<int, pair<int, int> > a)
  92. {
  93. return a.first;
  94. }
  95.  
  96. void calculate_city_info(int root, int tolls)
  97. {
  98. tour[tour_counter] = root;
  99. in_out[root].first = tour_counter;
  100. tour_counter++;
  101.  
  102.  
  103.  
  104. for (pair<int, pair<int, int> > edge: adj[root])
  105. {
  106. if (child(edge) != par[root])
  107. {
  108. par[child(edge)] = root;
  109. depth[child(edge)] = depth[root] + 1;
  110.  
  111. total_dist[child(edge)] = total_dist[root] + dist(edge); total_dist[child(edge)] %= mod;
  112.  
  113. calculate_city_info(child(edge), toll(edge));
  114. tour[tour_counter] = root;
  115. tour_counter++;
  116. }
  117. }
  118. in_out[root].second = tour_counter;
  119. update_all(root, tolls);
  120. }
  121.  
  122. int higher(int a, int b)
  123. {
  124. return depth[a] < depth[b] ? a: b;
  125. }
  126.  
  127. int find_lca(int first, int second)
  128. {
  129. int length = in_out[second].first - in_out[first].first + 1;
  130. int one = st[in_out[first].first][log[length]];
  131. int next_start = in_out[second].first - (1 << log[length]) + 1;
  132. int two = st[next_start][log[length]];
  133. return higher(one, two);
  134. }
  135.  
  136. int to_root(int vertex)
  137. {
  138. int minus = from_node.sum(in_out[vertex].first);
  139. minus = (minus+mod)%mod;
  140. int total = ((long long) total_dist[vertex] * total_toll.sum(in_out[vertex].first))%mod;
  141. total = (total+mod)%mod;
  142. return (total - minus + mod)%mod;
  143. }
  144. int from_root(int vertex)
  145. {
  146. return (to_node.sum(in_out[vertex].first)+mod)%mod;
  147. }
  148.  
  149. int find_to_lca(int start, int end, int lca, int to_end_toll)
  150. {
  151. int dist_to_lca = (total_dist[start] - total_dist[lca] + mod)%mod;
  152. int to_lca = to_root(start) - to_root(lca);
  153. to_lca = (to_lca + mod) %mod;
  154. to_lca -= ((long long)dist_to_lca * total_toll.sum(in_out[lca].first))%mod;
  155. to_lca = (to_lca+mod)%mod;
  156. to_lca += ((long long) dist_to_lca * to_end_toll)%mod;
  157. to_lca = (to_lca + mod)%mod;
  158. return to_lca;
  159. }
  160.  
  161. int find_from_lca(int end, int lca, int to_end_toll)
  162. {
  163. int dist_from_lca = (total_dist[end] - total_dist[lca] + mod)%mod;
  164. int from_lca = from_root(end) - from_root(lca);
  165. from_lca = (from_lca + mod) % mod;
  166. from_lca -= ((long long) to_end_toll * total_dist[lca])%mod;
  167. from_lca = (from_lca+mod)%mod;
  168. return from_lca;
  169. }
  170.  
  171. long long int solve(int start, int end, int g)
  172. {
  173. int first, second;
  174. tie(first, second) = in_out[start].first > in_out[end].first ? make_tuple(end, start) : make_tuple(start, end);
  175. int lca = find_lca(first, second);
  176. int to_end_total_toll = total_toll.sum(in_out[end].first) - total_toll.sum(in_out[lca].first);
  177. to_end_total_toll = (to_end_total_toll+mod)%mod;
  178. int to_lca = find_to_lca(start, end, lca, to_end_total_toll);
  179. int from_lca = find_from_lca(end, lca, to_end_total_toll);
  180.  
  181. int total_distance = total_dist[start] + total_dist[end] - ((2 * total_dist[lca])%mod);
  182. total_distance = (total_distance + mod)%mod;
  183.  
  184. return (((long long) total_distance * g)%mod + to_lca + from_lca) % mod;
  185. }
  186.  
  187. int main()
  188. {
  189. for (int i = 2; i < 2*maxN; i++) log[i] = log[i/2] +1;
  190. int n, g;
  191. //fin >> n >> g;
  192. cin >> n >> g;
  193.  
  194. for (int i = 0; i < n-1; i++)
  195. {
  196. int a, b, d, t;
  197. cin >> a >> b >> d >> t;
  198. //a = i+1; b = i+2; d = 1000000000; t = 0;
  199. pair<int, pair<int, int> > some = make_pair(a, make_pair(t, d));
  200. adj[b].push_back(some);
  201. some.first = b;
  202. adj[a].push_back(some);
  203. }
  204. calculate_city_info(1, 0);
  205. for (int i = 0; i < 2*n-1; i++)
  206. {
  207. st[i][0] = tour[i];
  208. }
  209. for (int len = 1, count = 1; count > 0; len++)
  210. {
  211. count = 0;
  212. for (int start = 0; start + (1<<len) < 2*n-1; start++, count++)
  213. {
  214. st[start][len] = higher(st[start][len-1], st[start+(1<<(len-1))][len-1]);
  215. }
  216. }
  217. int q; cin >> q;
  218. while(q--)
  219. {
  220. int type; cin >>type;
  221. if (type)
  222. {
  223. int start, end;
  224. cin >> start >> end;
  225. long long int ans = solve(start, end ,g);
  226. cout << ans << endl;
  227. }
  228. else
  229. {
  230. int x, y, t;
  231. cin >> x >> y >> t;
  232. if (depth[x] < depth[y])
  233. {
  234. update_all(y, t);
  235. }
  236. else
  237. {
  238. update_all(x, t);
  239. }
  240. }
  241. }
  242. }
  243.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement