Guest User

Untitled

a guest
Jun 18th, 2018
60
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 1.13 KB | None | 0 0
  1. class Module(object):
  2.  
  3. def __init__(self, graph_builder, parent_nodes):
  4. self._gb = graph_builder
  5. self._parents = parent_nodes
  6.  
  7.  
  8. class Multiply(Module):
  9.  
  10. def symbolic_backprop(self, grad_node):
  11. self.grad_node = grad_node
  12. a, b = self.parents
  13. return (
  14. self._gb.multiply(grad_node, b),
  15. self._gb.multiply(grad_node, a)
  16. )
  17.  
  18.  
  19. class Node(object):
  20.  
  21. def symbolic_backprop(self, grad_node):
  22.  
  23. parent_nodes = self.get_dependencies()
  24. parent_grads = self._module.symbolic_backprop(*self.output_grads)
  25.  
  26. for parent, grad in zip(parent_nodes, parent_grads):
  27. parent.symbolic_backprop(grad)
  28.  
  29.  
  30. class GraphBuilder(object):
  31.  
  32. def gradients(self, loss_node, target_node):
  33. loss_grad = self.constant(1.0)
  34. loss_node.symbolic_backprop(loss_grad)
  35. return target_node.grad_node
  36.  
  37.  
  38. if __name__ == '__main__':
  39.  
  40. gb = GraphBuilder()
  41. a = gb.Variable(2.0)
  42. b = gb.placeholder()
  43. loss = a * b
  44.  
  45. grad_a = gb.gradients(loss, a) # A Node() object
  46. new_a = gb.minus(a, grad_a)
  47. new_loss = new_a * b
  48.  
  49. # backproping through grad_a!
  50. grad_new_a = gb.gradients(new_loss, new_a) # Still a Node() object!
Add Comment
Please, Sign In to add comment