Guest User

Untitled

a guest
Jan 17th, 2018
100
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.02 KB | None | 0 0
  1. from itertools import ifilter
  2. from collections import defaultdict as dd
  3.  
  4.  
  5. class GEException(Exception)
  6.  
  7.  
  8. class CircularDependencyException(GEException):
  9. pass
  10.  
  11.  
  12. class NodeNotReadyException(GEException):
  13. pass
  14.  
  15.  
  16. class NoValue(object):
  17. pass
  18.  
  19.  
  20. no_value = NoValue()
  21.  
  22.  
  23. class Input(object):
  24. pass
  25.  
  26.  
  27. class SimpleInput(Input):
  28. def __init__(self, value):
  29. self._value = value
  30.  
  31. @property
  32. def value(self):
  33. return self._value
  34.  
  35. def __str__(self):
  36. return "<SimpleInput value=%s >" % self._value
  37.  
  38. __repr__ = __str__
  39.  
  40.  
  41. class NodeOutput(Input):
  42. def __init__(self, name):
  43. self._name = name
  44.  
  45. @property
  46. def name(self):
  47. return self._name
  48.  
  49. def _set_node(self, node):
  50. self._node = node
  51.  
  52. @property
  53. def value(self):
  54. return self._node._value
  55.  
  56. def __str__(self):
  57. return "<NodeOutput value=%s >" % (self._node.value if hasattr(self, "_node") else None)
  58.  
  59. __repr__ = __str__
  60.  
  61.  
  62. class Node(object):
  63. def __init__(self, name, f, **kwargs):
  64. # Checking if the kwargs are well-formed
  65. assert set(f.__code__.co_varnames) == set(kwargs.keys())
  66. for value in kwargs.itervalues():
  67. assert issubclass(value.__class__, Input)
  68.  
  69. self._f = f
  70. self._kwargs = kwargs
  71. self._name = name
  72. self._value = no_value
  73.  
  74. @property
  75. def ready(self):
  76. return self._value == no_value
  77.  
  78. @property
  79. def name(self):
  80. return self._name
  81.  
  82. @property
  83. def f(self):
  84. return self._f
  85.  
  86. @property
  87. def kwargs(self):
  88. return self._kwargs
  89.  
  90.  
  91. @property
  92. def value(self):
  93. return self._value
  94.  
  95. def execute(self):
  96. self._value = self._f(**{
  97. key: inp.value for key, inp in self._kwargs.iteritems()
  98. })
  99.  
  100. def __str__(self):
  101. return "<Node name=%s, value=%s>" % (self._name, self._value)
  102.  
  103. __repr__ = __str__
  104.  
  105.  
  106. class GraphExecutor(object):
  107. def __init__(self, *nodes):
  108. self._nodes = nodes
  109. self._value = no_value
  110.  
  111. @property
  112. def ready(self):
  113. return self._value != no_value
  114.  
  115. @property
  116. def value(self):
  117. return self._value
  118.  
  119. def execute(self):
  120. if self.ready:
  121. return
  122.  
  123. name_to_node = {node.name: node for node in self._nodes}
  124.  
  125. forward, backward = dd(set), dd(set)
  126. for node in self._nodes:
  127. cname = node.name
  128. for inp in node.kwargs.itervalues():
  129. if isinstance(inp, NodeOutput):
  130. pname = inp.name
  131. assert pname in name_to_node, "Referring to unknown node with name: %s" % pname
  132.  
  133. inp._set_node(name_to_node[pname])
  134. forward[pname].add(cname)
  135. backward[cname].add(pname)
  136.  
  137.  
  138. stack = [name for name in name_to_node.iterkeys() if not backward[name]]
  139. sinks = [name for name in name_to_node.iterkeys() if not forward[name]]
  140. assert len(sinks) == 1, "There must be exactly one sink node"
  141. sink = sinks[0]
  142.  
  143. if not stack:
  144. raise CircularDependencyException()
  145.  
  146. was = set(stack)
  147. while stack:
  148. nstack = []
  149. for elem in stack:
  150. name_to_node[elem].execute()
  151. for nxt in forward[elem]:
  152. backward[nxt].remove(elem)
  153. if not backward[nxt]:
  154. if nxt in was:
  155. raise CircularDependencyError()
  156. else:
  157. was.add(nxt)
  158. nstack.append(nxt)
  159.  
  160. stack = nstack
  161.  
  162. self._value = name_to_node[sink].value
  163.  
  164.  
  165.  
  166.  
  167.  
  168. def f(a, b, c):
  169. return a + b + c
  170.  
  171.  
  172. def g(c):
  173. return c * 2
  174.  
  175.  
  176. def h(d, e):
  177. return d - e
  178.  
  179.  
  180. if __name__ == "__main__":
  181. job_graph = [
  182. Node("f", f, a=SimpleInput(12), b=SimpleInput(13), c=SimpleInput(14)),
  183. Node("g", g, c=SimpleInput(15)),
  184. Node("h", h, d=NodeOutput("f"), e=NodeOutput("g"))
  185. ]
  186. ge = GraphExecutor(*job_graph)
  187. ge.execute()
  188. print ge.value
Add Comment
Please, Sign In to add comment