# IOI 2018 - Highway Tolls (partial solution, 51 points)

Jul 27th, 2022
1. #include "highway.h"
2. #include <iostream>
3. #include <algorithm>
4. #include <cassert>
5.
6. using namespace std;
7.
// Adjacency list: g[u] holds (neighbour, edge index) pairs.
vector<vector<pair<int, int>>> g;

// Per-vertex data for the tree as currently rooted (filled by rootTree):
vector<int> parent;  // parent vertex, -1 for the root
vector<int> edgeUp;  // index of the edge to the parent, -1 for the root
vector<int> depth;   // number of edges from the root
10.
11. void rootDFS(int node, int d = 0, int eUP = -1, int par = -1) {
12.   parent[node] = par;
13.   edgeUp[node] = eUP;
14.   depth[node] = d;
15.   for (auto [child, e] : g[node]) {
16.     if (child != par) rootDFS(child, d+1, e, node);
17.   }
18. }
19. void rootTree(int root, int N) {
20.   parent.resize(N);
21.   edgeUp.resize(N);
22.   depth.resize(N);
23.   rootDFS(root);
24. }
25. vector<int> nodesDepth(int d) {
26.   vector<int> v;
27.   for (int i = 0; i < (int)depth.size(); ++i) {
28.     if (depth[i] == d) v.push_back(i);
29.   }
30.   return v;
31. }
32.
33. void find_pair(int N, vector<int> U, vector<int> V, int A, int B) {
34.   g.resize(N);
35.   int M = U.size();
36.   for (int i = 0; i < M; ++i) {
37.     g[U[i]].emplace_back(V[i], i);
38.     g[V[i]].emplace_back(U[i], i);
39.   }
40.   vector<int> w(M);
41.   long long toll = ask(w);
42.
43.   rootTree(0, N);
44.   // Find the depth of the lower of the two nodes
45.   int l = 0, r = *max_element(depth.begin(), depth.end()); // l < d <= r
46.   while (l+1 < r) {
47.     int mid = (l+r)/2;
48.     // set all nodes lower than depth mid to have B edges leading up
49.     for (int i = 0; i < N; ++i) {
50.       if (depth[i] > mid) w[edgeUp[i]] = 1;
51.     }
52.     long long res = ask(w);
53.     for (int i = 0; i < N; ++i) {
54.       if (depth[i] > mid) w[edgeUp[i]] = 0;
55.     }
56.     if (res == toll) r = mid;
57.     else l = mid;
58.   }
59.   int lowerDepth = r;
60.
61.   auto findWithDepth = [&toll, &w](int depth) {
62.     vector<int> candidates = nodesDepth(depth);
63.     int l = 0, r = (int)candidates.size(); // in range [l, r)
64.     while (l+1 < r) {
65.       int mid = (l+r)/2;
66.       for (int i = mid; i < r; ++i) w[edgeUp[candidates[i]]] = 1, assert(edgeUp[candidates[i]] >= 0);
67.       long long res = ask(w);
68.       for (int i = mid; i < r; ++i) w[edgeUp[candidates[i]]] = 0;
69.       if (res == toll) r = mid;
70.       else l = mid;
71.     }
72.     return candidates[l];
73.   };
74.
75.   // Find the lower of the two nodes (or one of them if they have the same depth)
76.   int S = findWithDepth(lowerDepth);
77.
78.   // Find the other node by rooting the tree at the first node
79.   rootTree(S, N);
80.   assert(toll%A==0);
81.   int T = findWithDepth(toll/A);
82.   assert(depth[T] == toll/A);
83.   answer(S, T);
84. }
