Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- #include<stdio.h>
- #include<iostream>
- #include<vector>
- #include<cmath>
- #include<algorithm>
- #include<memory.h>
- #include<map>
- #include<set>
- #include<queue>
- #include<list>
- #include<sstream>
- #include<cstring>
- #include<numeric>
using namespace std;

// Upper bound on the number of vertices (vertices are stored 0-based).
const int N = 6e3;

// Adjacency lists of the undirected tree.
vector<int> g[N];
// Number of vertices read from input.
int n;
// ans[v]: depth of v below the root (written by dfs from its `path` argument).
int ans[N];
// Root vertex of the traversal.
int root;
// parent[v]: parent of v in the DFS tree (never written for the root, so it
// keeps its zero initial value).
int parent[N];
// Visited marks used by dfs.
bool used[N];
// Euler tour written by dfs. A tour of an n-vertex tree has 2n-1 entries, so
// it needs room for up to 2*N values; sizing it by N overflowed for n > N/2.
int order[2 * N];
// lvl[v]: depth of v (written by dfs from its `s` argument); this is the key
// the range-minimum segment tree is built on.
int lvl[N];
// first[v]: index in order[] where v first appears in the Euler tour.
int first[N];
// Next free slot in order[].
int id = 0;
// pcnt[v]: number of descendants of v, i.e. subtree size minus one (p_dfs).
int pcnt[N];
- void addEdge(int a, int b) {
- g[a].push_back(b);
- g[b].push_back(a);
- }
- bool pu[N];
- void p_dfs(int v) {
- pu[v] = true;
- for (auto c : g[v]) {
- if (!pu[c]) {
- p_dfs(c);
- pcnt[v] += pcnt[c] + 1;
- }
- }
- }
- void dfs(int v, int path, int s) {
- if (first[v] == -1) {
- first[v] = id;
- }
- order[id++] = v;
- lvl[v] = s;
- ans[v] = path;
- used[v] = true;
- for (auto & i : g[v]) {
- if (!used[i]) {
- parent[i] = v;
- dfs(i, path + 1, s + 1);
- order[id++] = v;
- }
- }
- }
// A (key, vertex) entry stored in the range-minimum segment tree: `val` is the
// value being minimised and `num` identifies the vertex that carries it.
// A default-constructed entry is (0, 0).
struct t {
    int val = 0;
    int num = 0;

    t() = default;
    t(int v, int n) : val(v), num(n) {}
};
- struct interval_tree {
- private:
- t tree[4 * N];
- public:
- void build(int v, int tl, int tr) {
- if (tl == tr) {
- tree[v] = t(lvl[order[tl]], order[tl]);
- }
- else {
- int tm = (tl + tr) / 2;
- build(v * 2, tl, tm);
- build(v * 2 + 1, tm + 1, tr);
- if (tree[v * 2].val > tree[v * 2 + 1].val) {
- tree[v] = tree[v * 2 + 1];
- }
- else {
- tree[v] = tree[v * 2];
- }
- }
- }
- t qmin(int v, int tl, int tr, int l, int r) {
- if (l > r)
- return t(INT32_MAX, -1);
- if (l == tl && r == tr)
- return tree[v];
- int tm = (tl + tr) / 2;
- t left = qmin(v * 2, tl, tm, l, min(r, tm));
- t right = qmin(v * 2 + 1, tm + 1, tr, max(l, tm + 1), r);
- if (left.val > right.val) {
- return right;
- }
- return left;
- }
- };
- interval_tree tree;
- int main() {
- //freopen("input.txt", "r", stdin);
- scanf("%d", &n);
- for (int i = 0; i < n - 1; i++) {
- int a, b;
- scanf("%d%d", &a, &b);
- addEdge(a - 1, b - 1);
- }
- root = 0;
- fill_n(first, N, -1);
- first[root] = 0;
- dfs(root, 0, 0);
- tree.build(1, 0, id - 1);
- p_dfs(root);
- int res = 0;
- for (int i = 0; i < n; i++) {
- for (int j = i + 1; j < n; j++) {
- t lca = tree.qmin(1, 0, id - 1, min(first[i], first[j]), max(first[i], first[j]));
- int nlca = lca.num;
- if (nlca == i) {
- int da = n - 1 - pcnt[j] - 1 - (ans[j] - ans[i] - 1);
- int db = pcnt[j];
- if (da == db) {
- res++;
- //printf("%d %d: %d %d %d\n", i, j, da, db, 1);
- }
- }
- else if (nlca == j) {
- int da = n - 1 - pcnt[i] - 1 - (ans[i] - ans[j] - 1);
- int db = pcnt[i];
- if (db == da) {
- res++;
- //printf("%d %d: %d %d %d\n", i, j, da, db, 2);
- }
- }
- else {
- int da;
- int db;
- if (ans[i] != ans[j]) {
- if (ans[i] > ans[j]) {
- int cur1 = i;
- while (parent[cur1] != nlca) {
- if (ans[i] - ans[parent[cur1]] < ans[j] + ans[parent[cur1]] - 2 * ans[nlca])
- cur1 = parent[cur1];
- else
- break;
- }
- da = pcnt[cur1];
- db = n - da - 2;
- }
- else {
- int cur1 = j;
- while (parent[cur1] != nlca) {
- if (ans[j] - ans[parent[cur1]] < ans[i] + ans[parent[cur1]] - 2 * ans[nlca])
- cur1 = parent[cur1];
- else
- break;
- }
- db = pcnt[cur1];
- da = n - db - 2;
- }
- }
- else {
- int cur1 = i;
- while (parent[cur1] != nlca) {
- cur1 = parent[cur1];
- }
- da = pcnt[cur1];
- cur1 = j;
- while (parent[cur1] != nlca) {
- cur1 = parent[cur1];
- }
- db = pcnt[cur1];
- }
- if (da == db) {
- res++;
- //printf("%d %d: %d %d %d\n", i, j, da, db, 2);
- }
- }
- }
- }
- printf("%d", res);
- return 0;
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement