Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- import collections
class Environment:
    """Factory, interning cache and algebra for symbolic expression nodes.

    Every ``make_*`` method goes through ``dict.setdefault`` on a per-kind
    table, so requesting the same node twice returns the instance stored by
    the first request.  ``add`` and ``multiply`` build ``Sum`` nodes, where
    ``Sum(minus, constant, nodes)`` stands for
    ``(-1 if minus else 1) * (constant + sum(nodes))`` (see ``Sum.at_zero``
    and the sign handling in ``add``).
    """

    def __init__(self):
        # Interning tables, one per node kind.
        self.parameter = {}
        self.sum = {}
        self.sigmoid = {}
        self.input = {}
        self.system_variable = {}
        self.exponent = {}
        self.product = {}
        # Memoized partial derivatives, keyed by (node, variable).
        self.sum_partial = {}
        self.product_partial = {}
        self.sigmoid_partial = {}
        self.exponent_partial = {}

    def _negate(self, nodes):
        """Return a tuple holding ``node.negate(self)`` for each node."""
        return tuple(item.negate(self) for item in nodes)

    def make_sum(self, minus, constant, nodes):
        """Return the interned Sum for ``(-1 if minus else 1) * (constant + sum(nodes))``.

        ``nodes`` must be a tuple of Products and ``minus`` a real bool.  A
        negated sum without terms is folded into the plain constant
        ``-constant`` so that each number has a single representation.
        """
        assert type(nodes) is tuple
        assert type(minus) is bool
        for node in nodes:
            assert isinstance(node, Product), node
        if not nodes and minus:
            result = self.make_sum(False, -constant, nodes)
        else:
            result = Sum(minus, constant, nodes)
        return self.sum.setdefault((minus, constant, nodes), result)

    def make_product(self, constant, nodes):
        """Return the interned Product ``constant * prod(nodes)``.

        The coefficient must be non-zero and ``nodes`` non-empty.
        """
        assert constant != 0
        assert len(nodes) > 0
        return self.product.setdefault((constant, nodes),
                                       Product(constant, nodes))

    def make_exponent(self, node, exponent):
        """Return the interned Exponent ``node ** exponent`` (exponent > 0)."""
        assert exponent > 0
        return self.exponent.setdefault((node, exponent),
                                        Exponent(node, exponent))

    def make_system_variable(self, index):
        """Return the interned SystemVariable with the given index."""
        return self.system_variable.setdefault(index, SystemVariable(index))

    def make_parameter(self, name):
        """Return the interned Parameter with the given name."""
        return self.parameter.setdefault(name, Parameter(name))

    def make_input(self, index):
        """Return the interned Input with the given index."""
        return self.input.setdefault(index, Input(index))

    def add(self, left, right):
        """Return the canonical Sum representing ``left + right``.

        Both operands are converted with ``to_sum``.  Their terms are merged
        using ``Product.fuzzy_cmp`` as the ordering, so each operand's terms
        are assumed to be ordered that way already (the merge preserves that
        order).  Like terms are combined, and terms whose coefficients cancel
        are dropped.  When the first surviving term has a negative
        coefficient, the whole sum is negated and ``minus`` is set, keeping
        the leading coefficient positive.
        """
        left_sum = left.to_sum(self)
        right_sum = right.to_sum(self)
        # Signed constant of the result (its value when every term is zero).
        constant = left_sum.at_zero() + right_sum.at_zero()
        if not left_sum.nodes:
            # Reuse the right operand's terms and sign; express the constant
            # in that sign's frame.
            if right_sum.minus:
                constant = -constant
            return self.make_sum(right_sum.minus, constant, right_sum.nodes)
        if not right_sum.nodes:
            if left_sum.minus:
                constant = -constant
            return self.make_sum(left_sum.minus, constant, left_sum.nodes)
        # Move both term lists into the un-negated frame before merging.
        if left_sum.minus:
            left_nodes = self._negate(left_sum.nodes)
        else:
            left_nodes = left_sum.nodes
        if right_sum.minus:
            right_nodes = self._negate(right_sum.nodes)
        else:
            right_nodes = right_sum.nodes
        left_length = len(left_nodes)
        right_length = len(right_nodes)
        node_list = []
        left_index = 0
        right_index = 0
        while left_index < left_length and right_index < right_length:
            left_node = left_nodes[left_index]
            right_node = right_nodes[right_index]
            comparison = left_node.fuzzy_cmp(right_node)
            if comparison == -1:
                node_list.append(left_node)
                left_index += 1
            elif comparison == 0:
                # Same factors: add coefficients; drop the term if they cancel.
                sum_constant = left_node.constant + right_node.constant
                if sum_constant != 0:
                    node_list.append(
                        self.make_product(sum_constant, left_node.nodes))
                left_index += 1
                right_index += 1
            else:
                node_list.append(right_node)
                right_index += 1
        # The loop exhausted at least one side, so at most one slice is non-empty.
        node_list.extend(left_nodes[left_index:])
        node_list.extend(right_nodes[right_index:])
        if node_list and node_list[0].constant < 0:
            # Negate the terms AND the constant, as the early-return branches
            # do: -((-constant) + sum(-terms)) == constant + sum(terms).
            # Keeping the constant's sign here would turn (x + 1) + (-2*x)
            # into -1 - x instead of 1 - x.
            return self.make_sum(True, -constant, self._negate(node_list))
        return self.make_sum(False, constant, tuple(node_list))

    def multiply(self, left, right):
        """Return the canonical Sum representing ``left * right``.

        Writing ``left = sL * (cL + L)`` and ``right = sR * (cR + R)``, where
        ``sL``/``sR`` come from the ``minus`` flags and ``L``/``R`` are the
        term sums, the product ``sL*sR * (cL*cR + cL*R + cR*L + L*R)`` is
        built from four partial sums.  Each partial sum carries the overall
        sign, and the four are then added together.
        """
        left_sum = left.to_sum(self)
        right_sum = right.to_sum(self)
        # Overall sign sL*sR of the product.
        minus = left_sum.minus != right_sum.minus
        # sL*sR * cL*cR
        first = self.make_sum(
            minus, left_sum.constant * right_sum.constant, ())
        # sL*sR * cL*R; a negative cL is moved into the sum's sign flag.
        outer_nodes = []
        if left_sum.constant:
            for node in right_sum.nodes:
                outer_node = self.make_product(
                    left_sum.constant * node.constant, node.nodes)
                if left_sum.constant < 0:
                    outer_node = outer_node.negate(self)
                outer_nodes.append(outer_node)
        outer = self.make_sum(
            (left_sum.constant < 0) != minus, 0, tuple(outer_nodes))
        # sL*sR * cR*L; a negative cR is moved into the sum's sign flag.
        inner_nodes = []
        if right_sum.constant:
            for node in left_sum.nodes:
                inner_node = self.make_product(
                    right_sum.constant * node.constant, node.nodes)
                if right_sum.constant < 0:
                    inner_node = inner_node.negate(self)
                inner_nodes.append(inner_node)
        inner = self.make_sum(
            (right_sum.constant < 0) != minus, 0, tuple(inner_nodes))
        # L*R: multiply every pair of terms, merging their factor tuples.
        last = self.number(0)
        for left_node in left_sum.nodes:
            for right_node in right_sum.nodes:
                combined_terms = self._combine_terms(
                    left_node.nodes, right_node.nodes)
                assert len(combined_terms) > 0
                last = self.add(
                    last,
                    self.make_product(left_node.constant * right_node.constant,
                                      combined_terms))
        if minus:
            # The cross terms need the overall sign too, like the three
            # partial sums above; otherwise (-x) * y would come out as x*y.
            last = last.negate(self)
        return self.add(self.add(self.add(first, outer), inner), last)

    def make_sigmoid(self, minus, node):
        """Return the interned Sigmoid of the Sum ``node`` with sign flag ``minus``.

        If ``node`` is negated, or is a negative constant, ``node`` is negated
        and ``minus`` flipped instead.  This representation treats
        sigmoid(-x) as -sigmoid(x).
        """
        if node.minus or (not node.nodes and node.constant < 0):
            result = self.make_sigmoid(not minus, node.negate(self))
        else:
            result = Sigmoid(minus, node)
        return self.sigmoid.setdefault((minus, node), result)

    def _combine_terms(self, left, right):
        """Multiply two factor tuples of Exponents.

        Both tuples are assumed to be ordered by their bases' ``cmp``; they
        are merged into one ordered tuple, and factors with equal bases are
        combined by adding their exponents.
        """
        left_length = len(left)
        right_length = len(right)
        assert left_length
        assert right_length
        node_list = []
        left_index = 0
        right_index = 0
        while left_index < left_length and right_index < right_length:
            left_node = left[left_index]
            right_node = right[right_index]
            comparison = left_node.node.cmp(right_node.node)
            if comparison == -1:
                node_list.append(left_node)
                left_index += 1
            elif comparison == 0:
                sum_constant = left_node.exponent + right_node.exponent
                assert sum_constant != 0
                node_list.append(self.make_exponent(left_node.node,
                                                    sum_constant))
                left_index += 1
                right_index += 1
            else:
                node_list.append(right_node)
                right_index += 1
        if left_index == left_length:
            node_list.extend(right[right_index:])
        if right_index == right_length:
            node_list.extend(left[left_index:])
        return tuple(node_list)

    def number(self, number):
        """Return the constant Sum whose value is ``number``."""
        return self.make_sum(False, number, ())
def cmp(left, right):
    """Three-way comparison, standing in for the ``cmp`` builtin Python 3 dropped.

    Returns 1 when ``left > right``, -1 when ``left < right`` and 0 otherwise
    (including when neither ordering holds).
    """
    is_greater = left > right
    is_less = left < right
    return is_greater - is_less
class State(collections.namedtuple('State', ['input', 'variables', 'parameters'])):
    """Record with ``input``, ``variables`` and ``parameters`` fields.

    The field names must be given as separate list items: a list holding the
    single string ``'input variables parameters'`` is not split by
    ``namedtuple`` and raises ValueError when the class is defined.
    ``parameters`` is expected to be a mapping (``read_only`` calls
    ``.items()`` on it).
    """

    def read_only(self):
        """Return ``(input, variables, parameters)`` with the parameters
        converted to a sorted tuple of ``(name, value)`` pairs."""
        # dict.iteritems() does not exist on Python 3; items() is equivalent here.
        parameters = tuple(sorted(self.parameters.items()))
        return self.input, self.variables, parameters
class Node(tuple):
    """Base class for immutable expression nodes stored as tuples.

    Equality also requires identical concrete types, so two node kinds with
    the same tuple payload (e.g. ``Input(0)`` and ``SystemVariable(0)``) are
    distinct values and distinct dictionary keys.
    """

    # Default sign flag; namedtuple-based subclasses with a ``minus`` field
    # shadow it through the MRO.
    minus = False

    def __eq__(self, other):
        return type(self) == type(other) and tuple(self) == tuple(other)

    def __ne__(self, other):
        # tuple defines its own __ne__, which would otherwise be found in the
        # MRO and ignore the type check performed by __eq__.
        return not self == other

    def __hash__(self):
        # Overriding __eq__ disables the inherited hash; tuple's hash stays
        # consistent with __eq__ because equal nodes have equal tuples.
        return super().__hash__()

    def to_sum(self, environment):
        raise NotImplementedError

    def partial_derivative(self, variable, environment):
        raise NotImplementedError
class Sum(collections.namedtuple('Sum', 'minus constant nodes'), Node):
    """Signed sum ``(-1 if minus else 1) * (constant + sum(nodes))``."""

    def to_sum(self, environment):
        """A Sum is already in sum form."""
        return self

    def at_zero(self):
        """Return the signed constant: the value when every term is zero."""
        return -self.constant if self.minus else self.constant

    def negate(self, environment):
        """Return the interned Sum with the sign flag flipped."""
        return environment.make_sum(not self.minus, self.constant, self.nodes)

    def partial_derivative(self, variable, environment):
        """Differentiate term by term (memoized in ``environment.sum_partial``)."""
        key = (self, variable)
        cached = environment.sum_partial.get(key)
        if cached is not None:
            return cached
        total = environment.number(0)
        for term in self.nodes:
            total = environment.add(
                total, term.partial_derivative(variable, environment))
        if self.minus:
            total = total.negate(environment)
        return environment.sum_partial.setdefault(key, total)

    def sum_cmp(self, other):
        """Order by constant, then term by term, then term count, then sign."""
        order = cmp(self.constant, other.constant)
        if order:
            return order
        for mine, theirs in zip(self.nodes, other.nodes):
            order = mine.cmp(theirs)
            if order:
                return order
        return (cmp(len(self.nodes), len(other.nodes))
                or cmp(self.minus, other.minus))
class NonSum(Node):
    """Node that is not a Sum; converted to one through its Product form."""

    def to_product(self, environment):
        raise NotImplementedError

    def to_sum(self, environment):
        """Wrap this node as a one-term Sum whose term has a non-negative
        coefficient; a negative coefficient is moved into ``minus``."""
        product = self.to_product(environment)
        flip = bool(product.constant < 0)
        if flip:
            product = product.negate(environment)
        return environment.make_sum(flip, 0, (product,))
class Product(collections.namedtuple('Product', 'constant nodes'), NonSum):
    """Monomial ``constant * prod(nodes)``; ``nodes`` holds Exponent factors.

    ``Environment.make_product`` interns instances and rejects a zero
    coefficient or an empty factor tuple.
    """

    def to_product(self, environment):
        """A Product is already in product form."""
        return self

    def negate(self, environment):
        """Return the interned Product with the opposite coefficient."""
        return environment.make_product(-self.constant, self.nodes)

    def cmp(self, other):
        """Total order: factors first (as ``fuzzy_cmp``), then the coefficient."""
        return self.fuzzy_cmp(other) or cmp(self.constant, other.constant)

    def fuzzy_cmp(self, other):
        """Order by factors only, ignoring the coefficient.

        Factors are compared pairwise by base, then by power; the factor
        count breaks ties.  ``add`` treats 0 as "like terms".
        """
        for mine, theirs in zip(self.nodes, other.nodes):
            base_order = mine[0].cmp(theirs[0])
            if base_order:
                return base_order
            power_order = cmp(mine[1], theirs[1])
            if power_order:
                return power_order
        return cmp(len(self.nodes), len(other.nodes))

    def partial_derivative(self, variable, environment):
        """Apply the product rule, then scale by the coefficient
        (memoized in ``environment.product_partial``)."""
        key = (self, variable)
        cached = environment.product_partial.get(key)
        if cached is not None:
            return cached
        factors = self.nodes
        if len(factors) == 1:
            derivative = factors[0].partial_derivative(variable, environment)
        else:
            # Sum over each factor of (the other factors) * d(factor).
            derivative = environment.number(0)
            for position, factor in enumerate(factors):
                others = environment.make_product(
                    1, factors[:position] + factors[position + 1:])
                derivative = environment.add(
                    derivative,
                    environment.multiply(
                        others,
                        factor.partial_derivative(variable, environment)))
        derivative = environment.multiply(
            derivative, environment.number(self.constant))
        return environment.product_partial.setdefault(key, derivative)
class Exponent(collections.namedtuple('Exponent', 'node exponent'), NonSum):
    """``node ** exponent``; ``Environment.make_exponent`` only builds
    positive exponents."""

    def to_product(self, environment):
        """Wrap as the single-factor Product ``1 * self``."""
        return environment.make_product(1, (self,))

    def partial_derivative(self, variable, environment):
        """Power rule with chain rule (memoized in ``environment.exponent_partial``)."""
        key = (self, variable)
        cached = environment.exponent_partial.get(key)
        if cached is not None:
            return cached
        if self.exponent == 1:
            derivative = self.node.partial_derivative(variable, environment)
        else:
            # exponent * d(node) * node ** (exponent - 1)
            coefficient = environment.number(self.exponent)
            chain = environment.multiply(
                self.node.partial_derivative(variable, environment),
                environment.make_exponent(self.node, self.exponent - 1))
            derivative = environment.multiply(coefficient, chain)
        return environment.exponent_partial.setdefault(key, derivative)
class NonExponent(NonSum):
    """NonSum that can serve as the base of an Exponent.

    ``priority`` orders node kinds against each other.  The default -1 marks
    a kind without an assigned priority, and ``cmp`` refuses to compare it.
    """

    priority = -1

    def to_exponent(self, environment):
        """Return the interned ``self ** 1``."""
        return environment.make_exponent(self, 1)

    def class_cmp(self, other):
        """Compare two nodes of the same kind; subclasses must override."""
        raise NotImplementedError

    def cmp(self, other):
        """Order by kind priority first, then by ``class_cmp`` within a kind."""
        if self.priority == -1 or other.priority == -1:
            raise NotImplementedError
        return cmp(self.priority, other.priority) or self.class_cmp(other)
class Atom(NonExponent):
    """Leaf node; its partial derivative is 0 unless a subclass overrides it."""

    def to_product(self, environment):
        """Wrap as ``1 * self ** 1``."""
        exponent = self.to_exponent(environment)
        return environment.make_product(1, (exponent,))

    def partial_derivative(self, variable, environment):
        return environment.number(0)
class Indexed(collections.namedtuple('Indexed', 'index'), Atom):
    """Atom identified by ``index``; ordered by that index within its kind."""

    def class_cmp(self, other):
        mine, theirs = self.index, other.index
        return cmp(mine, theirs)
class Input(Indexed):
    """Indexed input atom; inherits the zero partial derivative from Atom."""
class SystemVariable(Indexed):
    """Indexed atom that partial derivatives can be taken with respect to."""

    def partial_derivative(self, variable, environment):
        """Return 1 when ``variable`` is this variable's index, otherwise 0."""
        return environment.number(1 if self.index == variable else 0)
class Parameter(collections.namedtuple('Parameter', 'name'), Atom):
    """Named atom; ordered by name within its kind."""

    def class_cmp(self, other):
        mine, theirs = self.name, other.name
        return cmp(mine, theirs)
class Sigmoid(collections.namedtuple('Sigmoid', 'minus node'), NonExponent):
    """Activation applied to the Sum ``node``; ``minus`` negates the result.

    ``negate`` only flips ``minus``, and the derivative below is
    ``(-1 if minus else 1) * (1 - s**2) * d(node)`` with ``s`` the
    un-negated node.  NOTE(review): together these match an odd activation
    such as tanh rather than the logistic sigmoid — confirm intent.
    """

    def negate(self, environment):
        """Return the interned Sigmoid with the opposite sign flag."""
        return environment.make_sigmoid(not self.minus, self.node)

    def to_product(self, environment):
        """Return a Product whose factor is the un-negated sigmoid.

        A negated sigmoid becomes ``-1 * (un-negated sigmoid) ** 1``.
        """
        # to_exponent takes the environment; it was previously called
        # without it, which raised TypeError.
        if self.minus:
            return environment.make_product(
                -1, (self.negate(environment).to_exponent(environment),))
        return environment.make_product(1, (self.to_exponent(environment),))

    def class_cmp(self, other):
        """Order by the argument Sum, then by the sign flag."""
        cmp_result = self.node.sum_cmp(other.node)
        if cmp_result:
            return cmp_result
        return cmp(self.minus, other.minus)

    def partial_derivative(self, variable, environment):
        """Return ``(+/-)(1 - s**2) * d(node)/d(variable)`` (memoized in
        ``environment.sigmoid_partial``)."""
        result = environment.sigmoid_partial.get((self, variable))
        if result is not None:
            return result
        # Square the un-negated sigmoid, the same base to_product uses, so
        # equal squares intern to the same Exponent; (-s)**2 == s**2 anyway.
        base = self.negate(environment) if self.minus else self
        square = environment.make_exponent(base, 2)
        # (+/-)(1 - s**2), written with a positive leading term and the sign in
        # the ``minus`` flag: -((-1) + s**2) == 1 - s**2.
        slope = environment.make_sum(
            not self.minus, -1, (environment.make_product(1, (square,)),))
        result = environment.multiply(
            self.node.partial_derivative(variable, environment), slope)
        return environment.sigmoid_partial.setdefault((self, variable), result)
# Give each concrete NonExponent kind a distinct class-level ``priority``.
# NonExponent.cmp compares these first, so class_cmp only runs between
# nodes of the same kind.
for index, cls in enumerate([Input, SystemVariable, Parameter, Sigmoid]):
    cls.priority = index
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement