Advertisement
mwchase

LC:NN upload 1

Mar 4th, 2017
148
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 14.67 KB | None | 0 0
  1. import collections
  2.  
  3.  
  4. class Environment:
  5.  
  6.     def __init__(self):
  7.         self.parameter = {}
  8.         self.sum = {}
  9.         self.sigmoid = {}
  10.         self.input = {}
  11.         self.system_variable = {}
  12.         self.exponent = {}
  13.         self.product = {}
  14.  
  15.         self.sum_partial = {}
  16.         self.product_partial = {}
  17.         self.sigmoid_partial = {}
  18.         self.exponent_partial = {}
  19.  
  20.     def _negate(self, nodes):
  21.         return tuple(item.negate(self) for item in nodes)
  22.  
  23.     def make_sum(self, minus, constant, nodes):
  24.         assert type(nodes) is tuple
  25.         assert type(minus) is bool
  26.         for node in nodes:
  27.             assert isinstance(node, Product), node
  28.         if len(nodes) == 0 and minus:
  29.             result = self.make_sum(False, -constant, nodes)
  30.         else:
  31.             result = Sum(minus, constant, nodes)
  32.         return self.sum.setdefault((minus, constant, nodes), result)
  33.  
  34.     def make_product(self, constant, nodes):
  35.         assert constant != 0
  36.         assert len(nodes) > 0
  37.         return self.product.setdefault((constant, nodes),
  38.                                        Product(constant, nodes))
  39.  
  40.     def make_exponent(self, node, exponent):
  41.         assert exponent > 0
  42.         return self.exponent.setdefault((node, exponent),
  43.                                         Exponent(node, exponent))
  44.  
  45.     def make_system_variable(self, index):
  46.         return self.system_variable.setdefault(index, SystemVariable(index))
  47.  
  48.     def make_parameter(self, name):
  49.         return self.parameter.setdefault(name, Parameter(name))
  50.  
  51.     def make_input(self, index):
  52.         return self.input.setdefault(index, Input(index))
  53.  
  54.     def add(self, left, right):
  55.         left_sum = left.to_sum(self)
  56.         right_sum = right.to_sum(self)
  57.         constant = left_sum.at_zero() + right_sum.at_zero()
  58.         if not left_sum.nodes:
  59.             if right_sum.minus:
  60.                 constant = -constant
  61.             return self.make_sum(right_sum.minus, constant, right_sum.nodes)
  62.         if not right_sum.nodes:
  63.             if left_sum.minus:
  64.                 constant = -constant
  65.             return self.make_sum(left_sum.minus, constant, left_sum.nodes)
  66.         if left_sum.minus:
  67.             left_nodes = self._negate(left_sum.nodes)
  68.         else:
  69.             left_nodes = left_sum.nodes
  70.         if right_sum.minus:
  71.             right_nodes = self._negate(right_sum.nodes)
  72.         else:
  73.             right_nodes = right_sum.nodes
  74.         left_length = len(left_sum.nodes)
  75.         right_length = len(right_sum.nodes)
  76.         node_list = []
  77.         left_index = 0
  78.         right_index = 0
  79.         while left_index < left_length and right_index < right_length:
  80.             left_node = left_nodes[left_index]
  81.             right_node = right_nodes[right_index]
  82.             comparison = left_node.fuzzy_cmp(right_node)
  83.             if comparison == -1:
  84.                 node_list.append(left_node)
  85.                 left_index += 1
  86.             elif comparison == 0:
  87.                 sum_constant = left_node.constant + right_node.constant
  88.                 if sum_constant != 0:
  89.                     node_list.append(self.make_product(
  90.                         sum_constant, left_node.nodes))
  91.                 left_index += 1
  92.                 right_index += 1
  93.             else:
  94.                 node_list.append(right_node)
  95.                 right_index += 1
  96.             if left_index == left_length:
  97.                 node_list.extend(right_nodes[right_index:])
  98.             if right_index == right_length:
  99.                 node_list.extend(left_nodes[left_index:])
  100.         if len(node_list) > 0 and node_list[0].constant < 0:
  101.             minus = True
  102.             node_list = self._negate(node_list)
  103.         else:
  104.             node_list = tuple(node_list)
  105.             minus = False
  106.         return self.make_sum(minus, constant, node_list)
  107.  
  108.     def multiply(self, left, right):
  109.         left_sum = left.to_sum(self)
  110.         right_sum = right.to_sum(self)
  111.         minus = left_sum.minus != right_sum.minus
  112.         first = self.make_sum(
  113.             minus, left_sum.constant * right_sum.constant, ())
  114.         outer_nodes = []
  115.         if left_sum.constant:
  116.             for node in right_sum.nodes:
  117.                 outer_node = self.make_product(
  118.                     left_sum.constant * node.constant, node.nodes)
  119.                 if left_sum.constant < 0:
  120.                     outer_node = outer_node.negate(self)
  121.                 outer_nodes.append(outer_node)
  122.         outer = self.make_sum(
  123.             (left_sum.constant < 0) != minus, 0, tuple(outer_nodes))
  124.         inner_nodes = []
  125.         if right_sum.constant:
  126.             for node in left_sum.nodes:
  127.                 inner_node = self.make_product(
  128.                     right_sum.constant * node.constant, node.nodes)
  129.                 if right_sum.constant < 0:
  130.                     inner_node = inner_node.negate(self)
  131.                 inner_nodes.append(inner_node)
  132.         inner = self.make_sum(
  133.             (right_sum.constant < 0) != minus, 0, tuple(inner_nodes))
  134.         last = self.number(0)
  135.         for left_node in left_sum.nodes:
  136.             for right_node in right_sum.nodes:
  137.                 combined_terms = self._combine_terms(
  138.                     left_node.nodes, right_node.nodes)
  139.                 assert len(combined_terms) > 0
  140.                 last = self.add(
  141.                     last,
  142.                     self.make_product(left_node.constant * right_node.constant,
  143.                                       combined_terms))
  144.         return self.add(self.add(self.add(first, outer), inner), last)
  145.  
  146.     def make_sigmoid(self, minus, node):
  147.         if node.minus or (not node.nodes and node.constant < 0):
  148.             result = self.make_sigmoid(not minus, node.negate(self))
  149.         else:
  150.             result = Sigmoid(minus, node)
  151.         return self.sigmoid.setdefault((minus, node), result)
  152.  
  153.     def _combine_terms(self, left, right):
  154.         left_length = len(left)
  155.         right_length = len(right)
  156.         assert left_length
  157.         assert right_length
  158.         node_list = []
  159.         left_index = 0
  160.         right_index = 0
  161.         while left_index < left_length and right_index < right_length:
  162.             left_node = left[left_index]
  163.             right_node = right[right_index]
  164.             comparison = left_node.node.cmp(right_node.node)
  165.             if comparison == -1:
  166.                 node_list.append(left_node)
  167.                 left_index += 1
  168.             elif comparison == 0:
  169.                 sum_constant = left_node.exponent + right_node.exponent
  170.                 assert sum_constant != 0
  171.                 node_list.append(self.make_exponent(left_node.node,
  172.                                                     sum_constant))
  173.                 left_index += 1
  174.                 right_index += 1
  175.             else:
  176.                 node_list.append(right_node)
  177.                 right_index += 1
  178.             if left_index == left_length:
  179.                 node_list.extend(right[right_index:])
  180.             if right_index == right_length:
  181.                 node_list.extend(left[left_index:])
  182.         return tuple(node_list)
  183.  
  184.     def number(self, number):
  185.         return self.make_sum(False, number, ())
  186.  
  187.  
  188. def cmp(left, right):
  189.     return (left > right) - (left < right)
  190.  
  191.  
  192. class State(collections.namedtuple('State', ['input variables parameters'])):
  193.  
  194.     def read_only(self):
  195.         parameters = tuple(sorted(self.parameters.iteritems()))
  196.         return self.input, self.variables, parameters
  197.  
  198.  
  199. class Node(tuple):
  200.  
  201.     minus = False
  202.  
  203.     def __eq__(self, other):
  204.         return type(self) == type(other) and tuple(self) == tuple(other)
  205.  
  206.     def __hash__(self):
  207.         return super().__hash__()
  208.  
  209.     def to_sum(self, environment):
  210.         raise NotImplementedError
  211.  
  212.     def partial_derivative(self, variable, environment):
  213.         raise NotImplementedError
  214.  
  215.  
  216. class Sum(collections.namedtuple('Sum', 'minus constant nodes'), Node):
  217.  
  218.     def to_sum(self, environment):
  219.         return self
  220.  
  221.     def at_zero(self):
  222.         if self.minus:
  223.             return -self.constant
  224.         else:
  225.             return self.constant
  226.  
  227.     def negate(self, environment):
  228.         return environment.make_sum(not self.minus, self.constant, self.nodes)
  229.  
  230.     def partial_derivative(self, variable, environment):
  231.         result = environment.sum_partial.get((self, variable))
  232.         if result is not None:
  233.             return result
  234.         result = environment.number(0)
  235.         for node in self.nodes:
  236.             result = environment.add(
  237.                 result, node.partial_derivative(variable, environment))
  238.         if self.minus:
  239.             result = result.negate(environment)
  240.         return environment.sum_partial.setdefault((self, variable), result)
  241.  
  242.     def sum_cmp(self, other):
  243.         cmp_result = cmp(self.constant, other.constant)
  244.         if cmp_result:
  245.             return cmp_result
  246.         for self_node, other_node in zip(self.nodes, other.nodes):
  247.             cmp_result = self_node.cmp(other_node)
  248.             if cmp_result:
  249.                 return cmp_result
  250.         cmp_result = cmp(len(self.nodes), len(other.nodes))
  251.         if cmp_result:
  252.             return cmp_result
  253.         return cmp(self.minus, other.minus)
  254.  
  255.  
  256. class NonSum(Node):
  257.  
  258.     def to_product(self, environment):
  259.         raise NotImplementedError
  260.  
  261.     def to_sum(self, environment):
  262.         product = self.to_product(environment)
  263.         if product.constant < 0:
  264.             product = product.negate(environment)
  265.             minus = True
  266.         else:
  267.             minus = False
  268.         return environment.make_sum(minus, 0, (product,))
  269.  
  270.  
  271. class Product(collections.namedtuple('Product', 'constant nodes'), NonSum):
  272.  
  273.     def to_product(self, environment):
  274.         return self
  275.  
  276.     def negate(self, environment):
  277.         return environment.make_product(-self.constant, self.nodes)
  278.  
  279.     def cmp(self, other):
  280.         cmp_result = self.fuzzy_cmp(other)
  281.         if cmp_result:
  282.             return cmp_result
  283.         else:
  284.             return cmp(self.constant, other.constant)
  285.  
  286.     def fuzzy_cmp(self, other):
  287.         for self_node, other_node in zip(self.nodes, other.nodes):
  288.             cmp_result = self_node[0].cmp(other_node[0])
  289.             if cmp_result:
  290.                 return cmp_result
  291.             cmp_result = cmp(self_node[1], other_node[1])
  292.             if cmp_result:
  293.                 return cmp_result
  294.         return cmp(len(self.nodes), len(other.nodes))
  295.  
  296.     def partial_derivative(self, variable, environment):
  297.         result = environment.product_partial.get((self, variable))
  298.         if result is not None:
  299.             return result
  300.         if len(self.nodes) == 1:
  301.             result = self.nodes[0].partial_derivative(variable, environment)
  302.         else:
  303.             result = environment.number(0)
  304.             for index, focus in enumerate(self.nodes):
  305.                 remainder = environment.make_product(
  306.                     1,
  307.                     self.nodes[:index] + self.nodes[index + 1:])
  308.                 result = environment.add(
  309.                     result,
  310.                     environment.multiply(
  311.                         remainder,
  312.                         focus.partial_derivative(variable, environment)))
  313.         result = environment.multiply(
  314.             result,
  315.             environment.make_sum(False, self.constant, ()))
  316.         return environment.product_partial.setdefault((self, variable), result)
  317.  
  318.  
  319. class Exponent(collections.namedtuple('Exponent', 'node exponent'), NonSum):
  320.  
  321.     def to_product(self, environment):
  322.         return environment.make_product(1, (self,))
  323.  
  324.     def partial_derivative(self, variable, environment):
  325.         result = environment.exponent_partial.get((self, variable))
  326.         if result is not None:
  327.             return result
  328.         if self.exponent == 1:
  329.             result = self.node.partial_derivative(variable, environment)
  330.         else:
  331.             result = environment.multiply(
  332.                 environment.number(self.exponent),
  333.                 environment.multiply(
  334.                     self.node.partial_derivative(variable, environment),
  335.                     environment.make_exponent(self.node, self.exponent - 1)))
  336.         return environment.exponent_partial.setdefault(
  337.             (self, variable), result)
  338.  
  339.  
  340. class NonExponent(NonSum):
  341.  
  342.     priority = -1
  343.  
  344.     def to_exponent(self, environment):
  345.         return environment.make_exponent(self, 1)
  346.  
  347.     def class_cmp(self, other):
  348.         raise NotImplementedError
  349.  
  350.     def cmp(self, other):
  351.         if self.priority == -1 or other.priority == -1:
  352.             raise NotImplementedError
  353.         cmp_result = cmp(self.priority, other.priority)
  354.         if cmp_result:
  355.             return cmp_result
  356.         return self.class_cmp(other)
  357.  
  358.  
  359. class Atom(NonExponent):
  360.  
  361.     def to_product(self, environment):
  362.         return environment.make_product(1, (self.to_exponent(environment),))
  363.  
  364.     def partial_derivative(self, variable, environment):
  365.         return environment.number(0)
  366.  
  367.  
  368. class Indexed(collections.namedtuple('Indexed', 'index'), Atom):
  369.  
  370.     def class_cmp(self, other):
  371.         return cmp(self.index, other.index)
  372.  
  373.  
  374. class Input(Indexed):
  375.  
  376.     pass
  377.  
  378.  
  379. class SystemVariable(Indexed):
  380.  
  381.     def partial_derivative(self, variable, environment):
  382.         if self.index == variable:
  383.             return environment.number(1)
  384.         else:
  385.             return environment.number(0)
  386.  
  387.  
  388. class Parameter(collections.namedtuple('Parameter', 'name'), Atom):
  389.  
  390.     def class_cmp(self, other):
  391.         return cmp(self.name, other.name)
  392.  
  393.  
  394. class Sigmoid(collections.namedtuple('Sigmoid', 'minus node'), NonExponent):
  395.  
  396.     def negate(self, environment):
  397.         return environment.make_sigmoid(not self.minus, self.node)
  398.  
  399.     def to_product(self, environment):
  400.         if self.minus:
  401.             return environment.make_product(
  402.                 -1, (self.negate(environment).to_exponent(),))
  403.         else:
  404.             return environment.make_product(1, (self.to_exponent(),))
  405.  
  406.     def class_cmp(self, other):
  407.         cmp_result = self.node.sum_cmp(other.node)
  408.         if cmp_result:
  409.             return cmp_result
  410.         return cmp(self.minus, other.minus)
  411.  
  412.     def partial_derivative(self, variable, environment):
  413.         result = environment.sigmoid_partial.get((self, variable))
  414.         if result is not None:
  415.             return result
  416.         result = environment.multiply(
  417.             self.node.partial_derivative(variable),
  418.             environment.make_sum(
  419.                 self.minus, 1, (environment.make_product(-1, ((self, 2),)),)))
  420.         return environment.setdefault((self, variable), result)
  421.  
  422.  
  423. for index, cls in enumerate((Input, SystemVariable, Parameter, Sigmoid)):
  424.     cls.priority = index
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement