Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- #ifdef DEBUG
- #define _GLIBCXX_DEBUG
- #endif
- #include <bits/stdc++.h>
- using namespace std;
// Extended-precision floating type used for the constants below.
using ld = long double;

// Debug logging to stderr; compiled to an empty statement outside DEBUG builds.
#ifdef DEBUG
#define eprintf(...) fprintf(stderr, __VA_ARGS__), fflush(stderr)
#else
#define eprintf(...) ;
#endif

// Container size as a signed int (avoids signed/unsigned comparisons in loops).
#define sz(x) ((int) (x).size())
// Base name of the DEBUG-mode input/output files.
#define TASK "text"

// Large sentinel values and numeric tolerances.
const int inf = static_cast<int>(1.01e9);
const long long infll = static_cast<long long>(1.01e18);
const ld eps = 1e-9;
const ld pi = std::acos(static_cast<ld>(-1));

// Global random engine: a fixed seed in DEBUG for reproducible runs, time-seeded otherwise.
#ifdef DEBUG
std::mt19937 mrand(300);
#else
std::mt19937 mrand(std::chrono::steady_clock::now().time_since_epoch().count());
#endif
- int rnd(int x) {
- return mrand() % x;
- }
// Hook for one-time precomputation before any test case is read; currently nothing to do.
void precalc() {}
// Modulus for the printed answer.
const int mod = 998244353;

// Returns (a * b) mod `mod`; the product is formed in 64 bits so it cannot overflow.
int mul(int a, int b) {
  const long long product = static_cast<long long>(a) * b;
  return static_cast<int>(product % mod);
}

// In-place modular addition a = (a + b) mod `mod`; expects both operands in [0, mod).
void add(int &a, int b) {
  if ((a += b) >= mod) {
    a -= mod;
  }
}
// Capacity of the per-vertex / per-edge arrays below.
const int maxn = static_cast<int>(2e5) + 5;

int n;  // number of vertices; edges are numbered 0 .. n-2
int k;  // number of stops on the cyclic route
std::vector<int> g[maxn];    // g[v][j]: j-th neighbour of v
std::vector<int> ids[maxn];  // ids[v][j]: index of the edge leading to g[v][j]
int c[maxn];    // c[e]: length of edge e (summed into distances)
int d[maxn];    // d[e]: per-edge value copied into need[] for s-t path edges
int cvs[maxn];  // route stops, 0-based; cvs[k] repeats cvs[0]
int s;          // start of the s-t path (0-based)
int t;          // end of the s-t path (0-based)
- bool read() {
- if (scanf("%d%d", &n, &k) < 2) {
- return false;
- }
- for (int i = 0; i < n; i++) {
- g[i].clear();
- ids[i].clear();
- }
- for (int i = 0; i < n - 1; i++) {
- int v, u;
- scanf("%d%d%d%d", &v, &u, &c[i], &d[i]);
- v--;
- u--;
- g[v].push_back(u);
- ids[v].push_back(i);
- g[u].push_back(v);
- ids[u].push_back(i);
- }
- for (int i = 0; i < k; i++) {
- scanf("%d", &cvs[i]);
- cvs[i]--;
- }
- cvs[k] = cvs[0];
- scanf("%d%d", &s, &t);
- s--;
- t--;
- return true;
- }
- const int logn = 20;
- int ps[maxn][logn];
- int tin[maxn], tout[maxn], tt;
- vector<int> path;
- int ppos[maxn];
- long long need[maxn];
- long long xs[maxn];
- long long dep[maxn];
- int root[maxn];
- bool getPath(int v, int p, int t) {
- path.push_back(v);
- if (v == t) {
- return true;
- }
- for (int i = 0; i < sz(g[v]); i++) {
- int u = g[v][i];
- if (u == p) {
- continue;
- }
- if (getPath(u, v, t)) {
- return true;
- }
- }
- path.pop_back();
- return false;
- }
- void getDep(int v, int p, int rt) {
- tin[v] = tt++;
- root[v] = rt;
- ps[v][0] = (p == -1 ? v : p);
- for (int i = 1; i < logn; i++) {
- ps[v][i] = ps[ps[v][i - 1]][i - 1];
- }
- for (int i = 0; i < sz(g[v]); i++) {
- int u = g[v][i];
- if (u == p) {
- continue;
- }
- int e = ids[v][i];
- dep[u] = dep[v] + c[e];
- getDep(u, v, rt);
- }
- tout[v] = tt;
- }
- bool isAnc(int v, int u) {
- return tin[v] <= tin[u] && tout[u] <= tout[v];
- }
- int getLca(int v, int u) {
- if (isAnc(v, u)) {
- return v;
- }
- if (isAnc(u, v)) {
- return u;
- }
- for (int i = logn - 1; i >= 0; i--) {
- if (!isAnc(ps[v][i], u)) {
- v = ps[v][i];
- }
- }
- return ps[v][0];
- }
- long long L;
- int clen;
- vector<int> cyc;
- vector<long long> ts;
- vector<pair<long long, long long>> addEvs[2][maxn];
- vector<pair<long long, long long>> delEvs[2][maxn];
- struct node {
- pair<long long, long long> seg;
- long long len;
- long long mxlen;
- long long toShrink;
- int y;
- node *l, *r;
- node(const pair<long long, long long> &_seg): seg(_seg), len(seg.second - seg.first),
- mxlen(len), toShrink(0), y(mrand()), l(0), r(0) {}
- void shrink(long long x) {
- seg.first += x;
- seg.second -= x;
- len -= 2 * x;
- mxlen -= 2 * x;
- toShrink += x;
- }
- void push() {
- if (!toShrink) {
- return;
- }
- for (int it = 0; it < 2; it++) {
- node *u = (!it ? l : r);
- if (u) {
- u->shrink(toShrink);
- }
- }
- toShrink = 0;
- }
- node *recalc() {
- mxlen = len;
- for (int it = 0; it < 2; it++) {
- node *u = (!it ? l : r);
- if (u) {
- mxlen = max(mxlen, u->mxlen);
- }
- }
- return this;
- }
- };
- node *merge(node *l, node *r) {
- if (!l) {
- return r;
- }
- if (!r) {
- return l;
- }
- if (l->y < r->y) {
- l->push();
- l->r = merge(l->r, r);
- return l->recalc();
- } else {
- r->push();
- r->l = merge(l, r->l);
- return r->recalc();
- }
- }
- void split(node *v, long long x, node *&l, node *&r) {
- if (!v) {
- l = 0;
- r = 0;
- return;
- }
- v->push();
- if (x <= v->seg.first) {
- split(v->l, x, l, v->l);
- r = v->recalc();
- } else {
- split(v->r, x, v->r, r);
- l = v->recalc();
- }
- }
- node *del(node *v, const pair<long long, long long> &seg) {
- assert(v);
- v->push();
- if (v->seg == seg) {
- return merge(v->l, v->r);
- }
- assert(v->seg.first != seg.first);
- if (seg.first < v->seg.first) {
- v->l = del(v->l, seg);
- } else {
- v->r = del(v->r, seg);
- }
- return v->recalc();
- }
- node *add(node *v, node *u) {
- if (!v) {
- return u;
- }
- v->push();
- if (u->y < v->y) {
- split(v, u->seg.first, u->l, u->r);
- return u->recalc();
- }
- assert(u->seg.first != v->seg.first);
- if (u->seg.first < v->seg.first) {
- v->l = add(v->l, u);
- } else {
- v->r = add(v->r, u);
- }
- return v->recalc();
- }
- pair<long long, long long> getLst(node *v) {
- assert(v);
- if (!v->r) {
- return v->seg;
- }
- v->push();
- return getLst(v->r);
- }
- bool getT(node *v, long long need, long long &t) {
- if (!v) {
- return false;
- }
- if (v->mxlen < need) {
- return false;
- }
- v->push();
- if (getT(v->l, need, t)) {
- return true;
- }
- if (v->len >= need) {
- t = min(t, v->seg.first);
- return true;
- }
- return getT(v->r, need, t);
- }
- void solve() {
- path.clear();
- assert(getPath(s, -1, t));
- for (int i = 0; i < n; i++) {
- ppos[i] = -1;
- }
- for (int i = 0; i < sz(path); i++) {
- ppos[path[i]] = i;
- }
- xs[0] = 0;
- tt = 0;
- for (int i = 0; i < sz(path); i++) {
- int v = path[i];
- if (i + 1 < sz(path)) {
- int u = path[i + 1];
- for (int it = 0; it < 2; it++) {
- int pos = find(g[v].begin(), g[v].end(), u) - g[v].begin();
- int e = ids[v][pos];
- g[v].erase(g[v].begin() + pos);
- ids[v].erase(ids[v].begin() + pos);
- if (!it) {
- need[i] = d[e];
- xs[i + 1] = xs[i] + c[e];
- }
- swap(v, u);
- }
- }
- dep[v] = 0;
- getDep(v, -1, v);
- }
- {
- cyc.clear();
- ts.clear();
- long long curt = 0;
- for (int i = 0; i < k; i++) {
- int v = cvs[i];
- int u = cvs[i + 1];
- if (root[v] == root[u]) {
- curt += dep[v] + dep[u] - 2 * dep[getLca(v, u)];
- } else {
- curt += dep[v];
- v = root[v];
- cyc.push_back(v);
- ts.push_back(curt);
- curt += dep[u];
- u = root[u];
- curt += abs(xs[ppos[u]] - xs[ppos[v]]);
- }
- }
- L = curt;
- clen = sz(cyc);
- for (int i = 0; i < clen; i++) {
- cyc.push_back(cyc[i]);
- ts.push_back(L + ts[i]);
- }
- }
- for (int i = 0; i + 1 < sz(path); i++) {
- addEvs[0][i].clear();
- addEvs[1][i].clear();
- delEvs[0][i].clear();
- delEvs[1][i].clear();
- }
- if (clen) {
- int spos = 0;
- for (int i = 1; i < clen; i++) {
- if (ppos[cyc[i]] < ppos[cyc[spos]]) {
- spos = i;
- }
- }
- vector<pair<int, long long>> st;
- for (int i = spos; i < spos + clen; i++) {
- int v = cyc[i];
- int u = cyc[i + 1];
- long long curt = ts[i];
- if (ppos[u] < ppos[v]) {
- int w = v;
- while (!st.empty() && ppos[st.back().first] >= ppos[u]) {
- int nw = st.back().first;
- long long d = xs[ppos[w]] - xs[ppos[nw]];
- curt += d;
- long long l = st.back().second, r = curt;
- addEvs[0][ppos[nw]].push_back(make_pair(l, r));
- delEvs[0][ppos[w]].push_back(make_pair(l + d, r - d));
- st.pop_back();
- w = nw;
- }
- if (w != u) {
- long long d = xs[ppos[w]] - xs[ppos[u]];
- curt += d;
- long long l = st.back().second + (xs[ppos[u]] - xs[ppos[st.back().first]]), r = curt;
- addEvs[0][ppos[u]].push_back(make_pair(l, r));
- delEvs[0][ppos[w]].push_back(make_pair(l + d, r - d));
- }
- } else {
- st.push_back(make_pair(v, curt));
- }
- }
- }
- if (clen) {
- int spos = 0;
- for (int i = 1; i < clen; i++) {
- if (ppos[cyc[i]] > ppos[cyc[spos]]) {
- spos = i;
- }
- }
- vector<pair<int, long long>> st;
- for (int i = spos; i < spos + clen; i++) {
- int v = cyc[i];
- int u = cyc[i + 1];
- long long curt = ts[i];
- if (ppos[u] > ppos[v]) {
- int w = v;
- while (!st.empty() && ppos[st.back().first] <= ppos[u]) {
- int nw = st.back().first;
- long long d = xs[ppos[nw]] - xs[ppos[w]];
- curt += d;
- long long l = st.back().second, r = curt;
- delEvs[1][ppos[nw]].push_back(make_pair(l, r));
- addEvs[1][ppos[w]].push_back(make_pair(l + d, r - d));
- st.pop_back();
- w = nw;
- }
- if (w != u) {
- long long d = xs[ppos[u]] - xs[ppos[w]];
- curt += d;
- long long l = st.back().second + (xs[ppos[st.back().first]] - xs[ppos[u]]), r = curt;
- delEvs[1][ppos[u]].push_back(make_pair(l, r));
- addEvs[1][ppos[w]].push_back(make_pair(l + d, r - d));
- }
- } else {
- st.push_back(make_pair(v, curt));
- }
- }
- }
- node *rts[2] = {0, 0};
- int res = 0;
- long long curt = 0;
- for (int i = 0; i + 1 < sz(path); i++) {
- for (int t = 0; t < 2; t++) {
- for (int it = 0; it < sz(delEvs[t][i]); it++) {
- rts[t] = del(rts[t], delEvs[t][i][it]);
- }
- }
- for (int t = 0; t < 2; t++) {
- for (int it = 0; it < sz(addEvs[t][i]); it++) {
- rts[t] = add(rts[t], new node(addEvs[t][i][it]));
- }
- }
- if (rts[0]) {
- rts[0]->shrink(xs[i + 1] - xs[i]);
- } else {
- assert(!rts[1]);
- add(res, need[i] % mod);
- curt = (curt + need[i]) % L;
- continue;
- }
- long long nxtt = curt + L;
- for (int t = 0; t < 2; t++) {
- node *l, *m, *r;
- split(rts[t], curt, l, m);
- split(m, curt + L, m, r);
- if (l) {
- auto lst = getLst(l);
- if (curt + need[i] <= lst.second) {
- nxtt = min(nxtt, curt);
- }
- }
- if (m) {
- auto lst = getLst(m);
- if (curt + L + need[i] <= lst.second) {
- nxtt = min(nxtt, curt);
- }
- }
- long long nxtt0 = nxtt - L;
- if (getT(l, need[i], nxtt0)) {
- nxtt = min(nxtt, nxtt0 + L);
- }
- getT(m, need[i], nxtt);
- nxtt0 = 2 * L;
- if (getT(r, need[i], nxtt0)) {
- nxtt = min(nxtt, nxtt0 - L);
- }
- rts[t] = merge(l, merge(m, r));
- }
- rts[1]->shrink(-(xs[i + 1] - xs[i]));
- if (nxtt >= curt + L) {
- printf("-1\n");
- return;
- }
- nxtt += need[i];
- add(res, (nxtt - curt) % mod);
- curt = nxtt % L;
- }
- printf("%d\n", res);
- }
- int main() {
- precalc();
- #ifdef DEBUG
- assert(freopen(TASK ".in", "r", stdin));
- assert(freopen(TASK ".out", "w", stdout));
- #endif
- int t;
- scanf("%d", &t);
- while (read()) {
- solve();
- #ifdef DEBUG
- eprintf("Time %.2f\n", (double) clock() / CLOCKS_PER_SEC);
- #endif
- }
- return 0;
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement