Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- """Implement a recursive function to find all nodes between A and B on a DAG."""
- from typing import Iterable
- from typing import Optional
- from typing import Set
- from typing import List
- VISITED: Set["Node"] = set()  # nodes find_paths has already entered; used to skip repeat visits
- ON_THE_PATH: Set["Node"] = set()  # nodes find_paths determined to lie on a path to the target
class Node:
    """A simulated Luigi task node.

    A node is identified by its ``task_name``: two nodes with the same name
    compare equal and hash identically, whatever their requirements are.
    """

    def __init__(self, name: str) -> None:
        """Create a node called *name* with no requirements."""
        self.task_name = name
        self._requires: Set["Node"] = set()

    def __eq__(self, other: object) -> bool:
        """Return True when *other* is a node with the same task name."""
        return (
            isinstance(other, self.__class__)
            and self.task_name == other.task_name
        )

    def __hash__(self) -> int:
        """Return a hash derived from the task name only.

        The hash must not depend on ``_requires``: that set is mutable, so
        including it would change the hash after ``add_requires`` and make
        a node already stored in a set or dict unreachable. It would also
        let two nodes that compare equal hash differently.
        """
        return hash(self.task_name)

    def __repr__(self) -> str:
        """Return the task name."""
        return self.task_name

    def requires(self) -> List["Node"]:
        """Return the dependencies of the node as a new list."""
        return list(self._requires)

    def add_requires(self, requires: Iterable["Node"]) -> None:
        """Add every node in *requires* as a dependency of this node."""
        self._requires.update(requires)
def find_paths(node_a: Node, node_b: Optional[Node] = None) -> bool:
    """Check whether node_a is on a path leading upstream to node_b.

    The graph is walked recursively through ``requires()``. Every node
    entered is recorded in the module-level ``VISITED`` set, and every node
    found to be on a path is added to the module-level ``ON_THE_PATH`` set.
    Both sets persist between calls, so callers must reset them before
    starting a new search.

    If node_b is not provided, any upstream leaf (a node without
    requirements) counts as a destination.

    Args:
        node_a: The node to start from.
        node_b: The upstream node to reach, or None to accept any leaf.

    Returns:
        True if node_a equals node_b or at least one of its upstream nodes
        is on a path to the destination, False otherwise.
    """
    # Already explored: reuse the recorded answer instead of walking again.
    if node_a in VISITED:
        return node_a in ON_THE_PATH
    VISITED.add(node_a)

    if node_a == node_b:
        on_path = True
    else:
        upstream_nodes = node_a.requires()
        if upstream_nodes:
            # Build the full list on purpose: every upstream branch must be
            # explored so all nodes on any path get marked. A generator
            # would let any() stop at the first hit.
            on_path = any(
                [find_paths(node, node_b) for node in upstream_nodes]
            )
        else:
            # A leaf is a destination only when no explicit target was given.
            on_path = node_b is None

    if on_path:
        ON_THE_PATH.add(node_a)
    return on_path
def main() -> None:
    """Build a sample graph and run both path-finding checks on it."""
    # One node per task name, "9" down to "1".
    nodes = {label: Node(name=label) for label in "987654321"}

    # (node, nodes it requires), applied in this order.
    edges = (
        ("8", ("7",)),
        ("4", ("9", "8", "7")),
        ("3", ("7", "6")),
        ("2", ("6",)),
        ("1", ("5", "4", "3", "2")),
    )
    for owner, deps in edges:
        nodes[owner].add_requires([nodes[dep] for dep in deps])

    def marked_names(target: Optional[Node] = None) -> Set[str]:
        """Reset the shared state, search from node "1", return marked names."""
        global VISITED
        global ON_THE_PATH
        VISITED = set()
        ON_THE_PATH = set()
        find_paths(node_a=nodes["1"], node_b=target)
        return {marked.task_name for marked in ON_THE_PATH}

    def check(target: Optional[Node], expected: Set[str]) -> None:
        """Run one search, print the marked nodes and verify them."""
        found = marked_names(target)
        print("Nodes marked for delete:", ", ".join(sorted(found)))
        assert found == expected

    # From node "1" to node "7".
    check(nodes["7"], {"1", "3", "4", "7", "8"})
    # From node "1" to every upstream leaf.
    check(None, {"1", "2", "3", "4", "5", "6", "7", "8", "9"})


if __name__ == "__main__":
    main()
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement