Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- from collections import defaultdict, Counter,deque
- from math import sqrt, log10, log, floor, factorial,gcd,ceil,log2
- from bisect import bisect_left, bisect_right,insort
- from itertools import permutations,combinations
- from heapq import heapify,heappop,heappush
- import sys, io, os
# Fast line reader; note that returned lines keep their trailing "\n",
# which int() and str.split() tolerate.
input = sys.stdin.readline
# Faster alternative for huge inputs (yields bytes lines):
# input=io.BytesIO(os.read(0,os.fstat(0).st_size)).readline
# Raise this only if a recursive solution is ever used:
# sys.setrecursionlimit(10000)

# Shared numeric constants.
inf = float('inf')
mod = 10 ** 9 + 7


def yn(a):
    """Print "YES" when ``a`` is truthy, otherwise "NO"."""
    if a:
        print("YES")
    else:
        print("NO")


def cl(a, b):
    """Return ``(a + b - 1) // b``, i.e. ceil(a / b) for positive ``b``."""
    return (a + b - 1) // b
class DisjointSetUnion:
    """Union-find over ``0..n-1`` with union by size and path compression.

    Each node also carries a ``[val, count]`` tag in ``size_for_i``. The caller
    is expected to seed a node's tag before unioning it. ``union(a, b, val)``
    keeps the tag of the surviving root in step for ``val``:
    - a root whose tag names a different value is first reset to ``[val, 0]``;
    - the absorbed root's count is added only if its tag is ``val``.
    Tags on nodes that are no longer roots are stale and should not be read.
    """

    def __init__(self, n):
        self.parent = list(range(n))
        self.size = [1] * n
        # Per-node [value, count] tag; only meaningful on current roots.
        self.size_for_i = [[0, 0] for _ in range(n)]
        self.num_sets = n

    def find(self, a):
        """Return the root of ``a``'s set, compressing the walked path."""
        root = a
        while root != self.parent[root]:
            root = self.parent[root]
        # Second pass: re-point every node on the path directly at the root.
        # The right-hand side is evaluated first, so the old parent is kept
        # as the next node to visit.
        while a != root:
            self.parent[a], a = root, self.parent[a]
        return root

    def union(self, a, b, val):
        """Merge the sets of ``a`` and ``b`` and combine their ``val`` counts.

        Returns 1 if two distinct sets were merged, 0 if they already were one.
        Produces no output (a leftover debug print used to dump the internal
        arrays to stdout on every merge).
        """
        a, b = self.find(a), self.find(b)
        if a == b:
            return 0
        if self.size[a] < self.size[b]:
            a, b = b, a  # attach the smaller tree under the larger one
        self.num_sets -= 1
        self.parent[b] = a
        self.size[a] += self.size[b]
        tag_a, tag_b = self.size_for_i[a], self.size_for_i[b]
        if tag_a[0] != val:
            # Surviving root still counts a different value; restart it.
            tag_a[0] = val
            tag_a[1] = 0
        if tag_b[0] == val:
            tag_a[1] += tag_b[1]
        return 1

    def set_size(self, a):
        """Return the number of nodes in ``a``'s set."""
        return self.size[self.find(a)]

    def __len__(self):
        """Return the current number of disjoint sets."""
        return self.num_sets
def main(N, edges, V):
    """Count equal-valued vertex pairs whose connecting path never exceeds that value.

    Vertices are 1-indexed, and ``V[i - 1]`` is the value of vertex ``i``.
    ``edges`` holds the tree edges as ``[u, v]`` pairs.

    Vertices are activated in ascending order of value. When value ``val`` is
    processed, every vertex with that value is joined (via the DSU) to each
    neighbour whose value is <= ``val``. By then, every edge whose endpoints
    are both <= ``val`` has been unioned, so a DSU component is a connected
    region of vertices with values <= ``val``.

    Two value-``val`` vertices in the same component are therefore linked by
    a path whose maximum is ``val``. A component holding ``c`` such vertices
    contributes ``c * (c - 1) // 2`` pairs. Finally, each edge whose endpoints
    share a value is subtracted, so directly adjacent pairs are not counted.

    Prints the count and also returns it.
    """
    dsu = DisjointSetUnion(N + 1)
    graph = [[] for _ in range(N + 1)]
    for u, v in edges:
        graph[u].append(v)
        graph[v].append(u)

    # Group vertices by value so they can be activated in ascending order.
    values = defaultdict(list)
    for node in range(1, N + 1):
        values[V[node - 1]].append(node)

    counter = 0
    for val in sorted(values):
        nodes = values[val]
        # Each value-`val` vertex is still a singleton root here: it has
        # never been touched by an earlier (smaller) value.
        for j in nodes:
            dsu.size_for_i[j] = [val, 1]
        for j in nodes:
            for k in graph[j]:
                if V[k - 1] <= val:
                    dsu.union(k, j, val)
        for root in {dsu.find(j) for j in nodes}:
            c = dsu.size_for_i[root][1]
            counter += c * (c - 1) // 2

    # Drop the pairs that are joined directly by an edge. Scanning the edge
    # list replaces a BFS from vertex 1 that existed only to enumerate edges.
    for u, v in edges:
        if V[u - 1] == V[v - 1]:
            counter -= 1

    print(counter)
    return counter
- # N=5
- # edges=[[1,2],[1,3],[3,4],[3,5]]
- # V=[2,3,1,2,3]
- # main(N,edges,V)
- #
- # N=4
- # edges=[[1,2],[1,3],[3,4]]
- # V=[2,2,3,3]
- # main(N,edges,V)
def _read_ints():
    """Read one input line and return its whitespace-separated integers."""
    return [int(tok) for tok in input().split()]


# Input format: the number of test cases, then for each case N, followed by
# N - 1 lines of "u v" edges, followed by one line of N vertex values.
t = int(input())
for _case in range(t):
    N = int(input())
    edges = [_read_ints() for _ in range(N - 1)]
    V = _read_ints()
    main(N, edges, V)
- """
- 1
- 10
- 1 2
- 1 3
- 1 4
- 1 5
- 3 6
- 3 7
- 4 8
- 4 9
- 5 10
- 1 2 3 3 2 1 2 3 4 5
- 1
- 4
- 1 2
- 2 3
- 3 4
- 2 1 1 2
- """
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement