Advertisement
het_fadia

Untitled

Sep 21st, 2022 (edited)
31
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.26 KB | None | 0 0
  1. from collections import defaultdict, Counter,deque
  2. from math import sqrt, log10, log, floor, factorial,gcd,ceil,log2
  3. from bisect import bisect_left, bisect_right,insort
  4. from itertools import permutations,combinations
  5. from heapq import heapify,heappop,heappush
  6. import sys, io, os
  7. input = sys.stdin.readline
  8. # input=io.BytesIO(os.read(0,os.fstat(0).st_size)).readline
  9. # sys.setrecursionlimit(10000)
  10. inf = float('inf')
  11. mod = 10 ** 9 + 7
  12. def yn(a): print("YES" if a else "NO")
  13. cl = lambda a, b: (a + b - 1) // b
  14.  
  15.  
  16. class DisjointSetUnion:
  17. def __init__(self, n):
  18. self.parent = list(range(n))
  19. self.size = [1] * n
  20. self.size_for_i = [[0, 0] for i in range(n)]
  21. self.num_sets = n
  22.  
  23. def find(self, a):
  24. acopy = a
  25. while a != self.parent[a]:
  26. a = self.parent[a]
  27. while acopy != a:
  28. self.parent[acopy], acopy = a, self.parent[acopy]
  29. return a
  30.  
  31. def union(self, a, b, val):
  32. a, b = self.find(a), self.find(b)
  33. if a != b:
  34. if self.size[a] < self.size[b]:
  35. a, b = b, a
  36.  
  37. self.num_sets -= 1
  38. self.parent[b] = a
  39. self.size[a] += self.size[b]
  40. if self.size_for_i[a][0] != val:
  41. self.size_for_i[a][0] = val
  42. self.size_for_i[a][1] = 0
  43. if self.size_for_i[b][0] == val:
  44. self.size_for_i[a][1] += self.size_for_i[b][1]
  45. print(val, self.size_for_i, self.parent)
  46. return 1
  47. return 0
  48.  
  49. def set_size(self, a):
  50. return self.size[self.find(a)]
  51.  
  52. def __len__(self):
  53. return self.num_sets
  54.  
  55.  
  56. def main(N, edges, V):
  57. maxn = N + 1
  58. d = DisjointSetUnion(maxn)
  59. graph = [[] for i in range(N + 1)]
  60. for i, j in edges:
  61. graph[i].append(j)
  62. graph[j].append(i)
  63. values = defaultdict(list)
  64. for i in range(N):
  65. values[V[i]].append(i + 1)
  66. values_sorted = sorted(values)
  67.  
  68. counter = 0
  69. for val in values_sorted:
  70.  
  71. for j in values[val]:
  72. d.size_for_i[j] = [val, 1]
  73. for j in values[val]:
  74. for k in graph[j]:
  75. if V[k - 1] <= V[j - 1]:
  76. d.union(k, j, val)
  77. s = set()
  78. for j in values[val]:
  79. s.add(d.find(j))
  80. for roots in s:
  81. nodes_in_the_tree = d.size_for_i[roots][1]
  82. counter += (nodes_in_the_tree) * (nodes_in_the_tree - 1) // 2
  83. q=deque([1])
  84. visited=[0 for i in range(maxn)]
  85. visited[1]=1
  86. while q:
  87. s=q.popleft()
  88. for i in graph[s]:
  89. if visited[i]==0:
  90. visited[i]=1
  91. q.append(i)
  92. if V[s-1]==V[i-1]:
  93. counter-=1
  94. print(counter)
  95.  
  96.  
  97. # N=5
  98. # edges=[[1,2],[1,3],[3,4],[3,5]]
  99. # V=[2,3,1,2,3]
  100. # main(N,edges,V)
  101. #
  102. # N=4
  103. # edges=[[1,2],[1,3],[3,4]]
  104. # V=[2,2,3,3]
  105. # main(N,edges,V)
  106.  
  107.  
  108. t = int(input())
  109. for i in range(t):
  110. N = int(input())
  111. edges = []
  112. for i in range(N - 1):
  113. edges.append([int(i) for i in input().split()])
  114. V = [int(i) for i in input().split()]
  115. main(N, edges, V)
  116.  
  117. """
  118. 1
  119. 10
  120. 1 2
  121. 1 3
  122. 1 4
  123. 1 5
  124. 3 6
  125. 3 7
  126. 4 8
  127. 4 9
  128. 5 10
  129. 1 2 3 3 2 1 2 3 4 5
  130.  
  131. 1
  132. 4
  133. 1 2
  134. 2 3
  135. 3 4
  136. 2 1 1 2
  137. """
Tags: #google
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement