Advertisement
naman1601

cut-the-tree-iterative.py

Nov 2nd, 2020
1,920
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 2.13 KB | None | 0 0
  1. node_sums = {}
  2. edge_dict = {}
  3. seen = set()
  4.  
  5.  
  6.  
  7. def get_forward_branches(node):
  8.  
  9.     node_edges = edge_dict[node]
  10.     retlist = []
  11.  
  12.     for target_node in node_edges:
  13.         if target_node not in seen:
  14.             retlist.append(target_node)
  15.  
  16.     return retlist
  17.  
  18.  
  19.  
  20. def has_forward_branches(node):
  21.  
  22.     for sub_node in edge_dict[node]:
  23.         if sub_node not in seen:
  24.             return True
  25.  
  26.     return False
  27.  
  28.  
  29.  
  30. def get_sum(node, node_values):
  31.  
  32.     '''if node not in seen:
  33.         seen.add(node)'''
  34.  
  35.     if node in node_sums:
  36.         return node_sums[node]
  37.  
  38.     retval = node_values[node]
  39.  
  40.     nodes_to_visit = []
  41.  
  42.     if has_forward_branches(node):
  43.         nodes_to_visit.extend(get_forward_branches(node))
  44.  
  45.         for target_node in nodes_to_visit:
  46.  
  47.             '''if target_node not in seen:
  48.                 seen.add(target_node)'''
  49.  
  50.             if target_node in node_sums:
  51.                 retval += node_sums[target_node]
  52.  
  53.             else:
  54.                 retval += node_values[target_node]
  55.  
  56.                 if has_forward_branches(target_node):
  57.                     nodes_to_visit.extend(get_forward_branches(target_node))
  58.                 else:
  59.                     node_sums[target_node] = node_values[target_node]
  60.  
  61.  
  62.     node_sums[node] = retval
  63.     #print(nodes_to_visit)
  64.     #print(node_sums)
  65.     return retval
  66.  
  67.  
  68.  
  69. def cutTheTree(node_values, edges):
  70.  
  71.     node_values_sum = sum(node_values)
  72.     half_sum = node_values_sum / 2
  73.     req_sum = 0
  74.     global edge_dict
  75.     global node_sums
  76.  
  77.    
  78.     for edge in edges:
  79.  
  80.         if edge[0] - 1 in edge_dict:
  81.             edge_dict[edge[0] - 1].append(edge[1] - 1)
  82.  
  83.         else:
  84.             edge_dict[edge[0] - 1] = [edge[1] - 1]
  85.  
  86.  
  87.         if edge[1] - 1 in edge_dict:
  88.             edge_dict[edge[1] - 1].append(edge[0] - 1)
  89.  
  90.         else:
  91.             edge_dict[edge[1] - 1] = [edge[0] - 1]
  92.  
  93.  
  94.     nodes_to_visit = [0]
  95.  
  96.     for node in nodes_to_visit:
  97.         seen.add(node)
  98.         if has_forward_branches(node):
  99.             nodes_to_visit.extend(get_forward_branches(node))
  100.  
  101.     nodes_to_visit.reverse()
  102.     seen.clear()
  103.     #print(nodes_to_visit)
  104.  
  105.     for node in range(len(node_values)):
  106.         seen.add(node)
  107.    
  108.     for node in nodes_to_visit:
  109.         seen.remove(node)
  110.         current_node_sum = get_sum(node, node_values)
  111.  
  112.         if abs(half_sum - current_node_sum) < abs(half_sum - req_sum):
  113.             req_sum = current_node_sum
  114.  
  115.    
  116.     return int((2 * abs(half_sum - req_sum)))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement