Advertisement
Guest User

Untitled

a guest
Jan 23rd, 2019
387
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 11.04 KB | None | 0 0
  1. #ifdef DEBUG
  2. #define _GLIBCXX_DEBUG
  3. #endif
  4.  
  5. #include <bits/stdc++.h>
  6.  
  7. using namespace std;
  8.  
  9. typedef long double ld;
  10.  
  11. #ifdef DEBUG
  12. #define eprintf(...) fprintf(stderr, __VA_ARGS__), fflush(stderr)
  13. #else
  14. #define eprintf(...) ;
  15. #endif
  16.  
  17. #define sz(x) ((int) (x).size())
  18. #define TASK "text"
  19.  
  20. const int inf = (int) 1.01e9;
  21. const long long infll = (long long) 1.01e18;
  22. const ld eps = 1e-9;
  23. const ld pi = acos((ld) -1);
  24.  
  25. #ifdef DEBUG
  26. mt19937 mrand(300);
  27. #else
  28. mt19937 mrand(chrono::steady_clock::now().time_since_epoch().count());
  29. #endif
  30.  
  31. int rnd(int x) {
  32.   return mrand() % x;
  33. }
  34.  
  35. void precalc() {
  36. }
  37.  
  38. const int mod = 998244353;
  39.  
  40. int mul(int a, int b) {
  41.   return (long long) a * b % mod;
  42. }
  43.  
  44. void add(int &a, int b) {
  45.   a += b;
  46.   if (a >= mod) {
  47.     a -= mod;
  48.   }
  49. }
  50.  
  51. const int maxn = (int) 2e5 + 5;
  52. int n, k;
  53. vector<int> g[maxn];
  54. vector<int> ids[maxn];
  55. int c[maxn];
  56. int d[maxn];
  57. int cvs[maxn];
  58. int s, t;
  59.  
  60. bool read() {
  61.   if (scanf("%d%d", &n, &k) < 2) {
  62.     return false;
  63.   }
  64.   for (int i = 0; i < n; i++) {
  65.     g[i].clear();
  66.     ids[i].clear();
  67.   }
  68.   for (int i = 0; i < n - 1; i++) {
  69.     int v, u;
  70.     scanf("%d%d%d%d", &v, &u, &c[i], &d[i]);
  71.     v--;
  72.     u--;
  73.     g[v].push_back(u);
  74.     ids[v].push_back(i);
  75.     g[u].push_back(v);
  76.     ids[u].push_back(i);
  77.   }
  78.   for (int i = 0; i < k; i++) {
  79.     scanf("%d", &cvs[i]);
  80.     cvs[i]--;
  81.   }
  82.   cvs[k] = cvs[0];
  83.   scanf("%d%d", &s, &t);
  84.   s--;
  85.   t--;
  86.   return true;
  87. }
  88.  
  89. const int logn = 20;
  90. int ps[maxn][logn];
  91. int tin[maxn], tout[maxn], tt;
  92. vector<int> path;
  93. int ppos[maxn];
  94. long long need[maxn];
  95. long long xs[maxn];
  96. long long dep[maxn];
  97. int root[maxn];
  98.  
  99. bool getPath(int v, int p, int t) {
  100.   path.push_back(v);
  101.   if (v == t) {
  102.     return true;
  103.   }
  104.   for (int i = 0; i < sz(g[v]); i++) {
  105.     int u = g[v][i];
  106.     if (u == p) {
  107.       continue;
  108.     }
  109.     if (getPath(u, v, t)) {
  110.       return true;
  111.     }
  112.   }
  113.   path.pop_back();
  114.   return false;
  115. }
  116.  
  117. void getDep(int v, int p, int rt) {
  118.   tin[v] = tt++;
  119.   root[v] = rt;
  120.   ps[v][0] = (p == -1 ? v : p);
  121.   for (int i = 1; i < logn; i++) {
  122.     ps[v][i] = ps[ps[v][i - 1]][i - 1];
  123.   }
  124.   for (int i = 0; i < sz(g[v]); i++) {
  125.     int u = g[v][i];
  126.     if (u == p) {
  127.       continue;
  128.     }
  129.     int e = ids[v][i];
  130.     dep[u] = dep[v] + c[e];
  131.     getDep(u, v, rt);
  132.   }
  133.   tout[v] = tt;
  134. }
  135.  
  136. bool isAnc(int v, int u) {
  137.   return tin[v] <= tin[u] && tout[u] <= tout[v];
  138. }
  139.  
  140. int getLca(int v, int u) {
  141.   if (isAnc(v, u)) {
  142.     return v;
  143.   }
  144.   if (isAnc(u, v)) {
  145.     return u;
  146.   }
  147.   for (int i = logn - 1; i >= 0; i--) {
  148.     if (!isAnc(ps[v][i], u)) {
  149.       v = ps[v][i];
  150.     }
  151.   }
  152.   return ps[v][0];
  153. }
  154.  
  155. long long L;
  156. int clen;
  157. vector<int> cyc;
  158. vector<long long> ts;
  159. vector<pair<long long, long long>> addEvs[2][maxn];
  160. vector<pair<long long, long long>> delEvs[2][maxn];
  161.  
  162. struct node {
  163.   pair<long long, long long> seg;
  164.   long long len;
  165.   long long mxlen;
  166.   long long toShrink;
  167.   int y;
  168.   node *l, *r;
  169.  
  170.   node(const pair<long long, long long> &_seg): seg(_seg), len(seg.second - seg.first),
  171.                                                 mxlen(len), toShrink(0), y(mrand()), l(0), r(0) {}
  172.  
  173.   void shrink(long long x) {
  174.     seg.first += x;
  175.     seg.second -= x;
  176.     len -= 2 * x;
  177.     mxlen -= 2 * x;
  178.     toShrink += x;
  179.   }
  180.  
  181.   void push() {
  182.     if (!toShrink) {
  183.       return;
  184.     }
  185.     for (int it = 0; it < 2; it++) {
  186.       node *u = (!it ? l : r);
  187.       if (u) {
  188.         u->shrink(toShrink);
  189.       }
  190.     }
  191.     toShrink = 0;
  192.   }
  193.  
  194.   node *recalc() {
  195.     mxlen = len;
  196.     for (int it = 0; it < 2; it++) {
  197.       node *u = (!it ? l : r);
  198.       if (u) {
  199.         mxlen = max(mxlen, u->mxlen);
  200.       }
  201.     }
  202.     return this;
  203.   }
  204. };
  205.  
  206. node *merge(node *l, node *r) {
  207.   if (!l) {
  208.     return r;
  209.   }
  210.   if (!r) {
  211.     return l;
  212.   }
  213.   if (l->y < r->y) {
  214.     l->push();
  215.     l->r = merge(l->r, r);
  216.     return l->recalc();
  217.   } else {
  218.     r->push();
  219.     r->l = merge(l, r->l);
  220.     return r->recalc();
  221.   }
  222. }
  223.  
  224. void split(node *v, long long x, node *&l, node *&r) {
  225.   if (!v) {
  226.     l = 0;
  227.     r = 0;
  228.     return;
  229.   }
  230.   v->push();
  231.   if (x <= v->seg.first) {
  232.     split(v->l, x, l, v->l);
  233.     r = v->recalc();
  234.   } else {
  235.     split(v->r, x, v->r, r);
  236.     l = v->recalc();
  237.   }
  238. }
  239.  
  240. node *del(node *v, const pair<long long, long long> &seg) {
  241.   assert(v);
  242.   v->push();
  243.   if (v->seg == seg) {
  244.     return merge(v->l, v->r);
  245.   }
  246.   assert(v->seg.first != seg.first);
  247.   if (seg.first < v->seg.first) {
  248.     v->l = del(v->l, seg);
  249.   } else {
  250.     v->r = del(v->r, seg);
  251.   }
  252.   return v->recalc();
  253. }
  254.  
  255. node *add(node *v, node *u) {
  256.   if (!v) {
  257.     return u;
  258.   }
  259.   v->push();
  260.   if (u->y < v->y) {
  261.     split(v, u->seg.first, u->l, u->r);
  262.     return u->recalc();
  263.   }
  264.   assert(u->seg.first != v->seg.first);
  265.   if (u->seg.first < v->seg.first) {
  266.     v->l = add(v->l, u);
  267.   } else {
  268.     v->r = add(v->r, u);
  269.   }
  270.   return v->recalc();
  271. }
  272.  
  273. pair<long long, long long> getLst(node *v) {
  274.   assert(v);
  275.   if (!v->r) {
  276.     return v->seg;
  277.   }
  278.   v->push();
  279.   return getLst(v->r);
  280. }
  281.  
  282. bool getT(node *v, long long need, long long &t) {
  283.   if (!v) {
  284.     return false;
  285.   }
  286.   if (v->mxlen < need) {
  287.     return false;
  288.   }
  289.   v->push();
  290.   if (getT(v->l, need, t)) {
  291.     return true;
  292.   }
  293.   if (v->len >= need) {
  294.     t = min(t, v->seg.first);
  295.     return true;
  296.   }
  297.   return getT(v->r, need, t);
  298. }
  299.  
  300. void solve() {
  301.   path.clear();
  302.   assert(getPath(s, -1, t));
  303.   for (int i = 0; i < n; i++) {
  304.     ppos[i] = -1;
  305.   }
  306.   for (int i = 0; i < sz(path); i++) {
  307.     ppos[path[i]] = i;
  308.   }
  309.   xs[0] = 0;
  310.   tt = 0;
  311.   for (int i = 0; i < sz(path); i++) {
  312.     int v = path[i];
  313.     if (i + 1 < sz(path)) {
  314.       int u = path[i + 1];
  315.       for (int it = 0; it < 2; it++) {
  316.         int pos = find(g[v].begin(), g[v].end(), u) - g[v].begin();
  317.         int e = ids[v][pos];
  318.         g[v].erase(g[v].begin() + pos);
  319.         ids[v].erase(ids[v].begin() + pos);
  320.         if (!it) {
  321.           need[i] = d[e];
  322.           xs[i + 1] = xs[i] + c[e];
  323.         }
  324.         swap(v, u);
  325.       }
  326.     }
  327.     dep[v] = 0;
  328.     getDep(v, -1, v);
  329.   }
  330.   {
  331.     cyc.clear();
  332.     ts.clear();
  333.     long long curt = 0;
  334.     for (int i = 0; i < k; i++) {
  335.       int v = cvs[i];
  336.       int u = cvs[i + 1];
  337.       if (root[v] == root[u]) {
  338.         curt += dep[v] + dep[u] - 2 * dep[getLca(v, u)];
  339.       } else {
  340.         curt += dep[v];
  341.         v = root[v];
  342.         cyc.push_back(v);
  343.         ts.push_back(curt);
  344.         curt += dep[u];
  345.         u = root[u];
  346.         curt += abs(xs[ppos[u]] - xs[ppos[v]]);
  347.       }
  348.     }
  349.     L = curt;
  350.     clen = sz(cyc);
  351.     for (int i = 0; i < clen; i++) {
  352.       cyc.push_back(cyc[i]);
  353.       ts.push_back(L + ts[i]);
  354.     }
  355.   }
  356.   for (int i = 0; i + 1 < sz(path); i++) {
  357.     addEvs[0][i].clear();
  358.     addEvs[1][i].clear();
  359.     delEvs[0][i].clear();
  360.     delEvs[1][i].clear();
  361.   }
  362.   if (clen) {
  363.     int spos = 0;
  364.     for (int i = 1; i < clen; i++) {
  365.       if (ppos[cyc[i]] < ppos[cyc[spos]]) {
  366.         spos = i;
  367.       }
  368.     }
  369.     vector<pair<int, long long>> st;
  370.     for (int i = spos; i < spos + clen; i++) {
  371.       int v = cyc[i];
  372.       int u = cyc[i + 1];
  373.       long long curt = ts[i];
  374.       if (ppos[u] < ppos[v]) {
  375.         int w = v;
  376.         while (!st.empty() && ppos[st.back().first] >= ppos[u]) {
  377.           int nw = st.back().first;
  378.           long long d = xs[ppos[w]] - xs[ppos[nw]];
  379.           curt += d;
  380.           long long l = st.back().second, r = curt;
  381.           addEvs[0][ppos[nw]].push_back(make_pair(l, r));
  382.           delEvs[0][ppos[w]].push_back(make_pair(l + d, r - d));
  383.           st.pop_back();
  384.           w = nw;
  385.         }
  386.         if (w != u) {
  387.           long long d = xs[ppos[w]] - xs[ppos[u]];
  388.           curt += d;
  389.           long long l = st.back().second + (xs[ppos[u]] - xs[ppos[st.back().first]]), r = curt;
  390.           addEvs[0][ppos[u]].push_back(make_pair(l, r));
  391.           delEvs[0][ppos[w]].push_back(make_pair(l + d, r - d));
  392.         }
  393.       } else {
  394.         st.push_back(make_pair(v, curt));
  395.       }
  396.     }
  397.   }
  398.   if (clen) {
  399.     int spos = 0;
  400.     for (int i = 1; i < clen; i++) {
  401.       if (ppos[cyc[i]] > ppos[cyc[spos]]) {
  402.         spos = i;
  403.       }
  404.     }
  405.     vector<pair<int, long long>> st;
  406.     for (int i = spos; i < spos + clen; i++) {
  407.       int v = cyc[i];
  408.       int u = cyc[i + 1];
  409.       long long curt = ts[i];
  410.       if (ppos[u] > ppos[v]) {
  411.         int w = v;
  412.         while (!st.empty() && ppos[st.back().first] <= ppos[u]) {
  413.           int nw = st.back().first;
  414.           long long d = xs[ppos[nw]] - xs[ppos[w]];
  415.           curt += d;
  416.           long long l = st.back().second, r = curt;
  417.           delEvs[1][ppos[nw]].push_back(make_pair(l, r));
  418.           addEvs[1][ppos[w]].push_back(make_pair(l + d, r - d));
  419.           st.pop_back();
  420.           w = nw;
  421.         }
  422.         if (w != u) {
  423.           long long d = xs[ppos[u]] - xs[ppos[w]];
  424.           curt += d;
  425.           long long l = st.back().second + (xs[ppos[st.back().first]] - xs[ppos[u]]), r = curt;
  426.           delEvs[1][ppos[u]].push_back(make_pair(l, r));
  427.           addEvs[1][ppos[w]].push_back(make_pair(l + d, r - d));
  428.         }
  429.       } else {
  430.         st.push_back(make_pair(v, curt));
  431.       }
  432.     }
  433.   }
  434.   node *rts[2] = {0, 0};
  435.   int res = 0;
  436.   long long curt = 0;
  437.   for (int i = 0; i + 1 < sz(path); i++) {
  438.     for (int t = 0; t < 2; t++) {
  439.       for (int it = 0; it < sz(delEvs[t][i]); it++) {
  440.         rts[t] = del(rts[t], delEvs[t][i][it]);
  441.       }
  442.     }
  443.     for (int t = 0; t < 2; t++) {
  444.       for (int it = 0; it < sz(addEvs[t][i]); it++) {
  445.         rts[t] = add(rts[t], new node(addEvs[t][i][it]));
  446.       }
  447.     }
  448.     if (rts[0]) {
  449.       rts[0]->shrink(xs[i + 1] - xs[i]);
  450.     } else {
  451.       assert(!rts[1]);
  452.       add(res, need[i] % mod);
  453.       curt = (curt + need[i]) % L;
  454.       continue;
  455.     }
  456.     long long nxtt = curt + L;
  457.     for (int t = 0; t < 2; t++) {
  458.       node *l, *m, *r;
  459.       split(rts[t], curt, l, m);
  460.       split(m, curt + L, m, r);
  461.       if (l) {
  462.         auto lst = getLst(l);
  463.         if (curt + need[i] <= lst.second) {
  464.           nxtt = min(nxtt, curt);
  465.         }
  466.       }
  467.       if (m) {
  468.         auto lst = getLst(m);
  469.         if (curt + L + need[i] <= lst.second) {
  470.           nxtt = min(nxtt, curt);
  471.         }
  472.       }
  473.       long long nxtt0 = nxtt - L;
  474.       if (getT(l, need[i], nxtt0)) {
  475.         nxtt = min(nxtt, nxtt0 + L);
  476.       }
  477.       getT(m, need[i], nxtt);
  478.       nxtt0 = 2 * L;
  479.       if (getT(r, need[i], nxtt0)) {
  480.         nxtt = min(nxtt, nxtt0 - L);
  481.       }
  482.       rts[t] = merge(l, merge(m, r));
  483.     }
  484.     rts[1]->shrink(-(xs[i + 1] - xs[i]));
  485.     if (nxtt >= curt + L) {
  486.       printf("-1\n");
  487.       return;
  488.     }
  489.     nxtt += need[i];
  490.     add(res, (nxtt - curt) % mod);
  491.     curt = nxtt % L;
  492.   }
  493.   printf("%d\n", res);
  494. }
  495.  
  496. int main() {
  497.   precalc();
  498. #ifdef DEBUG
  499.   assert(freopen(TASK ".in", "r", stdin));
  500.   assert(freopen(TASK ".out", "w", stdout));
  501. #endif
  502.   int t;
  503.   scanf("%d", &t);
  504.   while (read()) {
  505.     solve();
  506. #ifdef DEBUG
  507.     eprintf("Time %.2f\n", (double) clock() / CLOCKS_PER_SEC);
  508. #endif
  509.   }
  510.   return 0;
  511. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement