Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
#include <assert.h>

#include <algorithm>
#include <cstdio>
#include <iostream>
#include <queue>
#include <vector>
using namespace std;
// One stall (a tree node). All Stall* links are non-owning pointers to other
// stalls; they must stay valid for as long as this stall is used.
struct Stall {
  Stall *parent = nullptr;      // BFS predecessor / parent in the rooted tree
  std::vector<Stall *> nexts;   // adjacent stalls (each edge stored both ways)
  int dist = -1;                // BFS distance / depth; -1 means "not visited"
  int flow = 0;                 // path count accumulated by the bottom-up pass
  int startflow = 0;            // number of path endpoints at this stall
  int stopflow = 0;             // number of paths whose LCA is this stall
  int children = 0;             // children already processed by a bottom-up pass
  std::vector<Stall *> aboves;  // aboves[i] caches the 2^i-th ancestor
};
// Integer base-2 logarithm: returns floor(log2(n)) (the index of the highest
// set bit) for n >= 1, and -1 for n <= 0.
int log2(int n) {
  // Guard: right-shifting a negative int is an arithmetic shift on common
  // targets and never reaches 0, which made the original loop spin forever.
  if (n <= 0)
    return -1;
  int result = -1;
  for (unsigned int bits = static_cast<unsigned int>(n); bits != 0; bits >>= 1)
    result++;
  return result;
}
- Stall *get2nabove(Stall *stall, int n) {
- Stall *result;
- if (n < stall->aboves.size()) {
- result = stall->aboves[n];
- } else {
- if (n == 0) {
- result = stall->parent;
- } else {
- result = get2nabove(get2nabove(stall, n - 1), n - 1);
- }
- stall->aboves.push_back(result);
- }
- return result;
- }
- Stall *LCA(Stall *a, Stall *b) {
- if (b->dist < a->dist)
- return LCA(b, a);
- int afromb = b->dist - a->dist;
- for (int i = 0; afromb; afromb >>= 1, i++) {
- if (afromb % 2 == 0) continue;
- b = get2nabove(b, i);
- }
- assert(a->dist == b->dist);
- int rbegin = 0;
- while (a->dist - rbegin > 1) {
- int toshift = log2(a->dist - rbegin) - 1;
- Stall *na = get2nabove(a, toshift);
- Stall *nb = get2nabove(b, toshift);
- assert(na->dist == nb->dist);
- if (na == nb)
- rbegin = na->dist + 1;
- else
- a = na, b = nb;
- }
- while (a != b) {
- a = a->parent;
- b = b->parent;
- }
- assert(a == b);
- return a;
- }
- int main() {
- FILE *fin = fopen("maxflow.in", "r");
- FILE *fout = fopen("maxflow.out", "w");
- int n, k;
- fscanf(fin, "%d %d", &n, &k);
- vector<Stall> stalls(n);
- for (int i = 0; i < n - 1; i++) {
- int x, y;
- fscanf(fin, "%d %d", &x, &y);
- x -= 1;
- y -= 1;
- stalls[x].nexts.push_back(&stalls[y]);
- stalls[y].nexts.push_back(&stalls[x]);
- }
- stalls[0].dist = 0;
- queue<Stall *> bfs;
- bfs.push(&stalls[0]);
- Stall *laststall = &stalls[0]; // Dummy value; should never be used.
- while (bfs.size()) {
- Stall *currstall = bfs.front();
- bfs.pop();
- for (Stall *nextstall : currstall->nexts) {
- if (nextstall->dist == -1) {
- nextstall->dist = currstall->dist + 1;
- bfs.push(nextstall);
- }
- }
- if (!bfs.size()) {
- laststall = currstall;
- }
- }
- laststall->dist = 0;
- bfs.push(laststall);
- Stall *laststall2 = laststall; // Dummy value; should never be used.
- while (bfs.size()) {
- Stall *currstall = bfs.front();
- bfs.pop();
- for (Stall *nextstall : currstall->nexts) {
- if (nextstall->parent == NULL) {
- nextstall->parent = currstall;
- nextstall->dist = currstall->dist + 1;
- bfs.push(nextstall);
- }
- }
- if (!bfs.size()) {
- laststall2 = currstall;
- }
- }
- assert(n == 1 || (laststall != laststall2));
- Stall *root = laststall2;
- for (int i = 0; i < laststall2->dist / 2; i++) {
- root = root->parent;
- }
- for (int i = 0; i < n; i++) {
- stalls[i].dist = -1;
- }
- root->parent = NULL;
- root->dist = 0;
- bfs.push(root);
- while (bfs.size()) {
- Stall *currstall = bfs.front();
- bfs.pop();
- for (Stall *nextstall : currstall->nexts) {
- if (nextstall->dist == -1) {
- nextstall->parent = currstall;
- nextstall->dist = currstall->dist + 1;
- bfs.push(nextstall);
- }
- }
- }
- for (int ki = 0; ki < k; ki++) {
- int s, t;
- fscanf(fin, "%d %d", &s, &t);
- stalls[s - 1].startflow++;
- stalls[t - 1].startflow++;
- Stall *lca = LCA(&stalls[s - 1], &stalls[t - 1]);
- lca->stopflow++;
- }
- queue<Stall *> q;
- for (int i = 0; i < n; i++) {
- if (stalls[i].nexts.size() == 1) {
- q.push(&stalls[i]);
- }
- }
- while (q.size()) {
- Stall *currstall = q.front();
- q.pop();
- currstall->flow += currstall->startflow - currstall->stopflow;
- if (currstall != root) {
- currstall->parent->flow += currstall->flow - currstall->stopflow;
- currstall->parent->children++;
- if (currstall->parent->children == currstall->parent->nexts.size() - 1) {
- q.push(currstall->parent);
- }
- }
- }
- /*for (int ki = 0; ki < k; ki++) {
- int s, t;
- fscanf(fin, "%d %d", &s, &t);
- Stall *currstall1 = &stalls[s - 1], *currstall2 = &stalls[t - 1];
- currstall1->flow--;
- currstall2->flow--;
- while (currstall1->dist > currstall2->dist) {
- currstall1 = currstall1->parent;
- currstall1->flow++;
- }
- while (currstall2->dist > currstall1->dist) {
- currstall2 = currstall2->parent;
- currstall2->flow++;
- }
- while (currstall1 != currstall2) {
- if (currstall1 != root) {
- currstall1 = currstall1->parent;
- currstall1->flow++;
- }
- if (currstall2 != root) {
- currstall2 = currstall2->parent;
- currstall2->flow++;
- }
- }
- currstall1->flow--;
- }*/
- int maxflow = 0;
- for (int i = 0; i < n; i++) {
- maxflow = max(maxflow, stalls[i].flow);
- }
- fprintf(fout, "%d\n", maxflow);
- return 0;
- }
Add Comment
Please, Sign In to add comment