Not a member of Pastebin yet?
Sign up — it unlocks many cool features!
class Module(object):
    """Base class for an operation ("module") in the computation graph.

    A module keeps a handle on the graph builder that created it, so that it
    can emit new nodes (for example gradient nodes). It also keeps the nodes
    that feed into it.
    """

    def __init__(self, graph_builder, parent_nodes):
        # Both references are stored as given: no copy and no validation.
        self._gb, self._parents = graph_builder, parent_nodes
class Multiply(Module):
    """Graph operation computing the product of its two parent nodes, a * b."""

    def symbolic_backprop(self, grad_node):
        """Return the symbolic gradients with respect to both parents.

        For out = a * b, d(out)/da = b and d(out)/db = a. By the chain rule,
        the upstream gradient is therefore multiplied by the *other* operand.

        Args:
            grad_node: symbolic gradient of the loss with respect to this
                module's output.

        Returns:
            A 2-tuple (grad_a, grad_b) of nodes built with the graph builder's
            ``multiply``, in the same order as the parents.

        Raises:
            ValueError: if the module does not have exactly two parents
                (raised by the tuple unpacking).
        """
        # Keep the incoming gradient on the module, as the original did.
        self.grad_node = grad_node
        # Module.__init__ stores the parents as ``_parents``. The original
        # read ``self.parents``, which is never assigned and raised
        # AttributeError on every call.
        a, b = self._parents
        return (
            self._gb.multiply(grad_node, b),
            self._gb.multiply(grad_node, a),
        )
class Node(object):
    """A value in the computation graph, optionally produced by a module."""

    def symbolic_backprop(self, grad_node):
        """Propagate a symbolic gradient backwards from this node.

        The gradient is recorded on the node as ``self.grad_node``, which is
        the attribute ``GraphBuilder.gradients`` reads back. The node then
        asks the module that produced it for the per-parent gradients and
        recurses into each parent.

        Args:
            grad_node: symbolic gradient of the loss with respect to this node.
        """
        # Record the gradient. Previously the argument was ignored and nothing
        # ever set ``grad_node`` on a node.
        self.grad_node = grad_node
        parent_nodes = self.get_dependencies()
        if not parent_nodes:
            # Leaf node: there is nothing upstream, so skip the module call.
            # NOTE(review): presumably variables, placeholders and constants
            # have no producing module; confirm.
            return
        # Module.symbolic_backprop takes exactly one gradient (see Multiply).
        # The original passed ``*self.output_grads``, which is never set here.
        parent_grads = self._module.symbolic_backprop(grad_node)
        for parent, grad in zip(parent_nodes, parent_grads):
            # NOTE(review): a parent reached along several paths (e.g. a * a)
            # keeps only the last gradient. Correct reverse-mode AD sums them,
            # which needs an "add" op on the graph builder.
            parent.symbolic_backprop(grad)
class GraphBuilder(object):
    """Builds computation-graph nodes and derives symbolic gradients.

    NOTE(review): node factories used here (``constant``) are not visible in
    this snippet and are presumably defined elsewhere on the class.
    """

    def gradients(self, loss_node, target_node):
        """Return a graph node holding d(loss_node)/d(target_node).

        Backprop is seeded at ``loss_node`` with the constant 1.0. The method
        then reads back the gradient that backprop stored on
        ``target_node.grad_node``.

        Args:
            loss_node: node to differentiate.
            target_node: node to differentiate with respect to.

        Returns:
            The symbolic gradient node recorded on ``target_node``.

        Raises:
            ValueError: if backprop did not record a gradient on
                ``target_node`` (for example, it does not feed into
                ``loss_node``).
        """
        # Gradients live on the nodes as state. Clear the target first so a
        # gradient left over from an earlier call is never returned as if it
        # were this call's result.
        target_node.grad_node = None
        loss_grad = self.constant(1.0)
        loss_node.symbolic_backprop(loss_grad)
        grad = target_node.grad_node
        if grad is None:
            raise ValueError(
                "target_node received no gradient from loss_node; "
                "is it an ancestor of loss_node?"
            )
        return grad
if __name__ == '__main__':
    # Build a tiny graph: loss = a * b, where `a` is a variable initialised
    # to 2.0 and `b` is a placeholder (presumably fed at run time).
    gb = GraphBuilder()
    a = gb.Variable(2.0)
    b = gb.placeholder()
    loss = a * b

    # The gradient comes back as a graph node, not a number, so it can be
    # used to build more graph.
    grad_a = gb.gradients(loss, a)

    # One descent-style update (presumably a - grad_a, i.e. step size 1),
    # followed by the loss recomputed from the updated value.
    new_a = gb.minus(a, grad_a)
    new_loss = new_a * b

    # Backprop from new_loss recurses through every ancestor, including the
    # nodes that produced grad_a. The result is again a symbolic node.
    grad_new_a = gb.gradients(new_loss, new_a)
Add Comment
Please sign in to add a comment.