Not a member of Pastebin yet?
Sign up; it unlocks many cool features!
- # coding: utf-8
- # In[37]:
# Operation class: base class for every computation node in the graph.
class Operation():
    """Base class for a graph node whose value is computed from other nodes.

    Constructing an Operation wires it into the graph:
    - it records its inputs;
    - it registers itself in each input's ``output_nodes``;
    - it appends itself to ``_default_graph.operations``.

    Subclasses implement :meth:`compute`, which receives the
    already-evaluated values of ``input_nodes`` as positional arguments.
    """

    def __init__(self, input_nodes=None):
        """Create the node and link it to its upstream nodes.

        Args:
            input_nodes: iterable of upstream nodes (objects exposing an
                ``output_nodes`` list). Defaults to no inputs.
        """
        # A fresh list per instance: the previous ``input_nodes=[]`` default
        # was one list object shared by every input-less Operation.
        self.input_nodes = list(input_nodes) if input_nodes is not None else []
        self.output_nodes = []
        for node in self.input_nodes:
            node.output_nodes.append(self)
        # Graph.set_as_default() must have been called before any node is made.
        _default_graph.operations.append(self)

    def compute(self, *inputs):
        """Return this node's value given the evaluated input values.

        Subclasses must override this; the base class has no computation.

        Raises:
            NotImplementedError: always, on the base class.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement compute()")
- # In[38]:
# add: an Operation whose value is the sum of its two inputs.
class add(Operation):
    """Graph node evaluating to ``x + y``."""

    def __init__(self, x, y):
        # Register both operands as upstream nodes of this one.
        super().__init__(input_nodes=[x, y])

    def compute(self, x_var, y_var):
        """Store the operand values on ``self.inputs`` and return their sum."""
        self.inputs = [x_var, y_var]
        total = x_var + y_var
        return total
- # In[39]:
# multiply: an Operation whose value is the product of its two inputs.
class multiply(Operation):
    """Graph node evaluating to ``x * y``."""

    def __init__(self, x, y):
        # Register both operands as upstream nodes of this one.
        super().__init__(input_nodes=[x, y])

    def compute(self, x_var, y_var):
        """Store the operand values on ``self.inputs`` and return their product."""
        self.inputs = [x_var, y_var]
        product = x_var * y_var
        return product
- # In[40]:
# matmul: an Operation whose value is the matrix product of its two inputs.
class matmul(Operation):
    """Graph node evaluating to ``x.dot(y)``.

    The operands must provide a ``.dot`` method (presumably numpy arrays;
    confirm against callers).
    """

    def __init__(self, x, y):
        # Register both operands as upstream nodes of this one.
        super().__init__(input_nodes=[x, y])

    def compute(self, x_var, y_var):
        """Store the operand values on ``self.inputs`` and return ``x_var.dot(y_var)``."""
        self.inputs = [x_var, y_var]
        result = x_var.dot(y_var)
        return result
- # In[41]:
class Placeholder:
    """Input node whose value is supplied at run time through ``feed_dict``."""

    def __init__(self):
        # Downstream Operations append themselves here when they are built.
        self.output_nodes = list()
        # Register with whichever graph is currently the default.
        graph = _default_graph
        graph.placeholders.append(self)
- # In[42]:
class Variable:
    """Graph node holding a stored value (``initial_value``, default ``None``)."""

    def __init__(self, initial_value=None):
        self.value = initial_value
        # Downstream Operations append themselves here when they are built.
        self.output_nodes = list()
        # Register with whichever graph is currently the default.
        graph = _default_graph
        graph.variables.append(self)
- # In[43]:
class Graph:
    """Registry of the operations, placeholders and variables of one graph.

    Nodes add themselves to the module-level ``_default_graph`` when they are
    constructed, so call :meth:`set_as_default` before creating nodes.
    """

    def __init__(self):
        self.operations, self.placeholders, self.variables = [], [], []

    def set_as_default(self):
        """Make this graph the module-level ``_default_graph``."""
        global _default_graph
        _default_graph = self
- # In[44]:
def traverse_postorder(operation):
    """Return ``operation`` and every node it depends on, in post-order.

    Each node appears after all of its inputs. Evaluating the list front to
    back therefore computes every input before the node that consumes it.
    ``operation`` itself is the last element.

    Only ``Operation`` instances are expanded through ``input_nodes``. Any
    other node (e.g. ``Variable`` or ``Placeholder``) is a leaf.

    A node reachable along several paths (a shared sub-expression) is listed
    only once, so it is evaluated only once.
    """
    nodes_postorder = []
    visited = set()  # ids of nodes already placed in nodes_postorder

    def recurse(node):
        if id(node) in visited:
            return
        visited.add(id(node))
        if isinstance(node, Operation):
            for input_node in node.input_nodes:
                recurse(input_node)
        nodes_postorder.append(node)

    recurse(operation)
    return nodes_postorder
- # In[45]:
class Session():
    """Evaluates nodes of a computation graph."""

    def run(self, operation, feed_dict=None):
        """Evaluate ``operation`` and return its value.

        Nodes are evaluated in the order given by ``traverse_postorder``, so
        every input's ``output`` is set before a consuming Operation runs.
        Each evaluated node's value is stored on its ``output`` attribute.

        Args:
            operation: the node to evaluate.
            feed_dict: mapping from each Placeholder reachable from
                ``operation`` to the value it takes for this run.
                Defaults to an empty mapping.

        Returns:
            The value of ``operation``. Any node value that is a list is
            converted to a numpy array.

        Raises:
            KeyError: if a reachable Placeholder has no entry in ``feed_dict``.
        """
        # Imported locally so run() does not depend on the module importing
        # numpy before this class is used.
        import numpy as np

        if feed_dict is None:
            feed_dict = {}

        for node in traverse_postorder(operation):
            if isinstance(node, Placeholder):
                try:
                    node.output = feed_dict[node]
                except KeyError:
                    raise KeyError(
                        f"no value in feed_dict for placeholder {node!r}"
                    ) from None
            elif isinstance(node, Variable):
                node.output = node.value
            else:  # Operation: its inputs were evaluated earlier in post-order.
                node.inputs = [input_node.output for input_node in node.input_nodes]
                node.output = node.compute(*node.inputs)

            if isinstance(node.output, list):
                node.output = np.array(node.output)

        return operation.output
- # In[46]:
import numpy as np

# Build the graph z = A * x + b; nodes register with the default graph,
# so it must be set before any node is created.
g = Graph()
g.set_as_default()

A = Variable(10)
b = Variable(1)
x = Placeholder()

y = multiply(A, x)
z = add(y, b)

# Evaluate A * x + b with the placeholder x fed the value 10.
sess = Session()
result = sess.run(operation=z, feed_dict={x: 10})
Add Comment
Please sign in to add a comment.