Advertisement
FyanRu

SPOJ-COT3

May 26th, 2024
449
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 3.24 KB | None | 0 0
  1. #include <algorithm>
  2. #include <array>
  3. #include <bitset>
  4. #include <cassert>
  5. #include <chrono>
  6. #include <complex>
  7. #include <cstdio>
  8. #include <cstring>
  9. #include <deque>
  10. #include <iomanip>
  11. #include <iostream>
  12. #include <iterator>
  13. #include <list>
  14. #include <map>
  15. #include <memory>
  16. #include <numeric>
  17. #include <queue>
  18. #include <random>
  19. #include <set>
  20. #include <stack>
  21. #include <string>
  22. #include <tuple>
  23. #include <vector>
  24. using namespace std;
  25. #define int long long
  26. #define all(x) begin(x), end(x)
  27. #define sz(x) (int) (x).size()
  28.  
  29. const int MXN = 5e6;
  30.  
  31. int su[MXN], le[MXN], ri[MXN], lz[MXN], de[MXN];
  32. int nx = 1;
  33. vector<int> vls[MXN];
  34.  
  35. int create_node(int d = 17) {
  36.     de[nx] = d; nx += 1;
  37.     assert(nx < MXN);
  38.     return nx-1;
  39. }
  40. void xor_all(int i, int v) {
  41.     lz[i] ^= v;
  42. }
  43. void push_node(int i) {
  44.     if (!de[i]) return;
  45.     if ((1ll<<(de[i]-1)) & lz[i]) {
  46.         swap(le[i], ri[i]);
  47.     }
  48.     if (le[i]) lz[le[i]] ^= lz[i];
  49.     if (ri[i]) lz[ri[i]] ^= lz[i];
  50.     lz[i] = 0;
  51. }
  52. void pull_node(int i) {
  53.     if (!de[i]) return;
  54.     push_node(i);
  55.     su[i] = su[le[i]] + su[ri[i]];
  56. }
  57. int insert_value(int v, int ci, int cl = 0, int cr = (1ll<<17)-1) {
  58.     pull_node(ci);
  59.     if (v == cl && v == cr) {
  60.         if (su[ci]) return 0;
  61.         su[ci] = 1;
  62.         return 1;
  63.     } else {
  64.         int cm = (cl+cr)/2;
  65.         int res = 0;
  66.         if (v <= cm) {
  67.             if (!le[ci]) le[ci] = create_node(de[ci]-1);
  68.             res = insert_value(v, le[ci], cl, cm);
  69.         } else {
  70.             if (!ri[ci]) ri[ci] = create_node(de[ci]-1);
  71.             res = insert_value(v, ri[ci], cm+1, cr);
  72.         }
  73.         pull_node(ci);
  74.         return res;
  75.     }
  76. }
  77. void add_value(int ci, int v) {
  78.     if (insert_value(v,ci)) {
  79.         vls[ci].push_back(v);
  80.     }
  81. }
  82. int get_mex(int ci) {
  83.     if (!de[ci]) return su[ci];
  84.     int nc = (1ll<<de[ci]);
  85.     pull_node(ci);
  86.     if (su[ci] == nc) return nc;
  87.     int lmx = get_mex(le[ci]);
  88.     if (lmx == nc/2) {
  89.         return nc/2+get_mex(ri[ci]);
  90.     }
  91.     return lmx;
  92. }
  93. int merge(int a, int b) {
  94.     if (!a) return b;
  95.     if (!b) return a;
  96.     pull_node(a);
  97.     pull_node(b);
  98.     le[a] = merge(le[a], le[b]);
  99.     ri[a] = merge(ri[a], ri[b]);
  100.     pull_node(a);
  101.     return a;
  102. }
  103.  
  104. int n, c[MXN], rt[MXN], sg[MXN], ans[MXN];
  105. vector<int> adj[MXN];
  106.  
  107. void dfs(int v, int p) {
  108.     if (find(all(adj[v]),p) != end(adj[v])) {
  109.         adj[v].erase(find(all(adj[v]),p));
  110.     }
  111.     rt[v] = create_node();
  112.     int txor = 0; //total xor
  113.     for (int u : adj[v]) {
  114.         assert(u != p);
  115.         dfs(u, v);
  116.         txor ^= sg[u];
  117.     }
  118. //  cout << v+1 << ": " << txor << "\n";
  119.     for (int u : adj[v]) {
  120.         xor_all(rt[u], txor^sg[u]);
  121.         rt[v] = merge(rt[v], rt[u]);
  122.     }
  123.  
  124.     if (!c[v]) {
  125.         insert_value(txor, rt[v]);
  126.     }
  127.     sg[v] = get_mex(rt[v]);
  128. }
  129.  
  130. void dfs2(int v, int p) {
  131.     int txor = 0;
  132.     for (int i : adj[v]) {
  133.         txor ^= sg[i];
  134.     }
  135.     ans[v] = sg[v]^ans[p]^txor;
  136.     for (int i : adj[v]) {
  137.         dfs2(i, v);
  138.     }
  139. }
  140.  
  141. signed main() {
  142.     ios::sync_with_stdio(false); cin.tie(nullptr);
  143.  
  144.     cin >> n;
  145.     for (int i = 0; i < n; i++) cin >> c[i];
  146.     for (int i = 1; i < n; i++) {
  147.         int u, v; cin >> u >> v; --u; --v;
  148.         adj[u].push_back(v); adj[v].push_back(u);
  149.     }
  150.     dfs(0,0);
  151. /*  cout << "\n\n";
  152.     for (int i = 0; i < n; i++) {
  153.         cout << i+1 << ": " << sg[i] << "\n";
  154.     }
  155.     cout << "\n\n";*/
  156.     ans[0] = sg[0];
  157.     dfs2(0,0);
  158.     if (!sg[0]) cout << -1 << "\n";
  159.     for (int i = 0; i < n; i++) {
  160.         if (!c[i] && !ans[i]) cout << i+1 << "\n";
  161.     }
  162.    
  163.     return 0;
  164. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement