Advertisement
Guest User

Untitled

a guest
Jun 16th, 2019
105
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 3.03 KB | None | 0 0
  1. """Implement a recursive function to find all nodes between A and B on a DAG."""
  2.  
  3. from typing import Iterable
  4. from typing import Optional
  5. from typing import Set
  6. from typing import List
  7.  
  8.  
  9. VISITED: Set["Node"] = set()
  10. ON_THE_PATH: Set["Node"] = set()
  11.  
  12.  
  13. class Node:
  14. """A simulated Luigi task node."""
  15.  
  16. def __init__(self, name: str) -> None:
  17. self.task_name = name
  18. self._requires: Set["Node"] = set()
  19.  
  20. def __eq__(self, other: object) -> bool:
  21. return (
  22. isinstance(other, self.__class__)
  23. and self.task_name == other.task_name
  24. )
  25.  
  26. def __hash__(self) -> int:
  27. return hash(
  28. "+".join([self.task_name] + [r.task_name for r in self._requires])
  29. )
  30.  
  31. def __repr__(self) -> str:
  32. return self.task_name
  33.  
  34. def requires(self) -> List["Node"]:
  35. """Return the dependancies of the node."""
  36. return list(self._requires)
  37.  
  38. def add_requires(self, requires: Iterable["Node"]) -> None:
  39. """Add a denpendaency."""
  40. for require in requires:
  41. self._requires.add(require)
  42.  
  43.  
  44. def find_paths(node_a: Node, node_b: Optional[Node] = None) -> bool:
  45. """Check node_a is on the path between node_a and node_b.
  46.  
  47. If node_b is not provided, check node_a to any upstream
  48. leaf.
  49. """
  50. if node_a in VISITED:
  51. return node_a in ON_THE_PATH
  52.  
  53. VISITED.add(node_a)
  54.  
  55. if node_a == node_b:
  56. is_on_path = True
  57. elif not node_a.requires():
  58. is_on_path = not node_b
  59. else:
  60. upstreams = [find_paths(node, node_b) for node in node_a.requires()]
  61. is_on_path = any(upstreams)
  62.  
  63. if on_path:
  64. ON_THE_PATH.add(node_a)
  65.  
  66. return on_path
  67.  
  68.  
  69. def main() -> None:
  70. node_9 = Node(name="9")
  71. node_8 = Node(name="8")
  72. node_7 = Node(name="7")
  73. node_6 = Node(name="6")
  74. node_5 = Node(name="5")
  75. node_4 = Node(name="4")
  76. node_3 = Node(name="3")
  77. node_2 = Node(name="2")
  78. node_1 = Node(name="1")
  79.  
  80. node_8.add_requires([node_7])
  81. node_4.add_requires([node_9, node_8, node_7])
  82. node_3.add_requires([node_7, node_6])
  83. node_2.add_requires([node_6])
  84. node_1.add_requires([node_5, node_4, node_3, node_2])
  85.  
  86. def test_from_a_to_b() -> None:
  87. global VISITED
  88. global ON_THE_PATH
  89. VISITED = set()
  90. ON_THE_PATH = set()
  91.  
  92. find_paths(node_a=node_1, node_b=node_7)
  93. nodes_on_the_path = {n.task_name for n in ON_THE_PATH}
  94. print("Nodes marked for delete:", ", ".join(sorted(nodes_on_the_path)))
  95. assert nodes_on_the_path == set(["1", "3", "4", "7", "8"])
  96.  
  97. def test_from_a_to_all() -> None:
  98. global VISITED
  99. global ON_THE_PATH
  100. VISITED = set()
  101. ON_THE_PATH = set()
  102.  
  103. find_paths(node_a=node_1)
  104. nodes_on_the_path = {n.task_name for n in ON_THE_PATH}
  105. print("Nodes marked for delete:", ", ".join(sorted(nodes_on_the_path)))
  106. assert nodes_on_the_path == set(["1", "2", "3", "4", "5", "6", "7", "8", "9"])
  107.  
  108. test_from_a_to_b()
  109. test_from_a_to_all()
  110.  
  111.  
  112. if __name__ == "__main__":
  113. main()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement