Advertisement
kosievdmerwe

Untitled

Sep 4th, 2021
75
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.22 KB | None | 0 0
  1. class Tree:
  2.     def __init__(self, val):
  3.         self.val = val
  4.         self.parent = None
  5.         self.children = []
  6.        
  7.         self.dist_descendants = 0
  8.         self.num_descendants = 1
  9.        
  10.     def precalc(self) -> None:
  11.         for child in self.children:
  12.             child.precalc()
  13.             self.num_descendants += child.num_descendants
  14.             self.dist_descendants += (
  15.                 child.dist_descendants + child.num_descendants
  16.             )
  17.        
  18.  
  19. class Solution:
  20.     def sumOfDistancesInTree(self, n: int, edges: List[List[int]]) -> List[int]:
  21.         edge_map = defaultdict(set)
  22.         for e in edges:
  23.             edge_map[e[0]].add(e[1])
  24.             edge_map[e[1]].add(e[0])
  25.         print(edge_map)
  26.        
  27.         trees = [Tree(i) for i in range(n)]
  28.         root = trees[0]
  29.        
  30.         def construct_tree(node) -> None:
  31.             for child_idx in edge_map[node.val]:
  32.                 if child_idx == node.parent:
  33.                     continue
  34.                 child = trees[child_idx]
  35.                 node.children.append(child)
  36.                 child.parent = node.val
  37.                 construct_tree(child)
  38.         construct_tree(root)
  39.         root.precalc()
  40.        
  41.         ans = [-1] * n
  42.         def calc_ans(
  43.             node: Tree,
  44.             non_node_cnt: int = 0,
  45.             non_node_dist: int = 0,
  46.         ) -> None:
  47.             ans[node.val] = (
  48.                 node.dist_descendants
  49.                 + non_node_dist
  50.             )
  51.             for child in node.children:
  52.                 # Sibling includes the parent
  53.                 sibling_cnt = node.num_descendants - child.num_descendants
  54.                 sibling_dist = (
  55.                     node.dist_descendants
  56.                     # remove the child subtree
  57.                     - (child.dist_descendants + child.num_descendants)
  58.                     # add the extra edge for each sibling
  59.                     + sibling_cnt
  60.                 )
  61.                
  62.                 calc_ans(
  63.                     child,
  64.                     non_node_cnt + sibling_cnt,
  65.                     (non_node_dist + non_node_cnt) + sibling_dist,
  66.                 )
  67.         calc_ans(root)
  68.            
  69.         return ans
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement