gt22

Untitled

Sep 21st, 2020
943
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. from __future__ import annotations
  2. from abc import ABC, abstractmethod
  3. from typing import Dict
  4. import numpy as np
  5. from math import log
  6.  
  7.  
  8. def literal_convert(x):
  9.     if isinstance(x, Expression):
  10.         return x
  11.     if isinstance(x, int):
  12.         return Constant(float(x))
  13.     if isinstance(x, float):
  14.         return Constant(x)
  15.     if isinstance(x, str):
  16.         try:
  17.             return Constant(float(x))
  18.         except ValueError:
  19.             return Variable(x)
  20.     raise ValueError(f"Can't convert ({x})@{type(x)} to expression")
  21.  
  22.  
  23. class Expression(ABC):
  24.  
  25.     subexprs: Dict[str, Expression]
  26.  
  27.     @abstractmethod
  28.     def value(self, var: Dict[str, float]) -> float:
  29.         raise NotImplemented
  30.  
  31.     @abstractmethod
  32.     def _proper_deriv(self, sub: str) -> Expression:
  33.         raise NotImplemented
  34.  
  35.     def deriv(self, by: str):
  36.         d = Constant(0)
  37.         for name, expr in self.subexprs.items():
  38.             d += expr.deriv(by) * self._proper_deriv(name)
  39.         return d
  40.  
  41.     def _const_reduce(self):
  42.         try:
  43.             return Constant(self.value({}))
  44.         except ValueError:
  45.             return self
  46.  
  47.     def reduce(self):
  48.         return self._const_reduce()
  49.  
  50.     def __eq__(self, other):
  51.         return type(self) is type(other) and all(other.subexprs[name] == e for name, e in self.subexprs.items())
  52.  
  53.     def __add__(self, other):
  54.         return Plus(self, literal_convert(other))
  55.  
  56.     def __mul__(self, other):
  57.         return Mul(self, literal_convert(other))
  58.  
  59.     def __neg__(self):
  60.         return self * -1
  61.  
  62.     def __sub__(self, other):
  63.         return self + (-other)
  64.  
  65.     def __pow__(self, power, modulo=None):
  66.         if modulo is not None:
  67.             raise NotImplemented
  68.         return Pow(self, literal_convert(power))
  69.  
  70.     def __truediv__(self, other):
  71.         return self * (other ** -1)
  72.  
  73.  
  74. class Constant(Expression):
  75.  
  76.     val: float
  77.  
  78.     def __init__(self, val: float):
  79.         self.val = val
  80.         self.subexprs = {}
  81.  
  82.     def value(self, var: Dict[str, float]) -> float:
  83.         return self.val
  84.  
  85.     def _proper_deriv(self, sub: str) -> Expression:
  86.         raise NotImplemented
  87.  
  88.     def deriv(self, var: str) -> Expression:
  89.         return Constant(0)
  90.  
  91.     def __str__(self):
  92.         return str(self.val)
  93.  
  94.     def __eq__(self, other):
  95.         return super(Constant, self).__eq__(other) and self.val == other.val
  96.  
  97.  
  98. class Variable(Expression):
  99.  
  100.     name: str
  101.  
  102.     def __init__(self, name: str):
  103.         self.name = name
  104.         self.subexprs = {}
  105.  
  106.     def value(self, var: Dict[str, float]) -> float:
  107.         if self.name not in var:
  108.             raise ValueError(f"Variable {self.name} not found")
  109.         return var[self.name]
  110.  
  111.     def _proper_deriv(self, sub: str) -> Expression:
  112.         raise NotImplemented
  113.  
  114.     def deriv(self, var: str) -> Expression:
  115.         return Constant(1 if var == self.name else 0)
  116.  
  117.     def __str__(self):
  118.         return self.name
  119.  
  120.     def __eq__(self, other):
  121.         return super(Variable, self).__eq__(other) and self.name == other.name
  122.  
  123.  
  124. class Plus(Expression):
  125.  
  126.     left: Expression
  127.     right: Expression
  128.  
  129.     def __init__(self, left: Expression, right: Expression):
  130.         self.subexprs = {
  131.             'left': left,
  132.             'right': right
  133.         }
  134.         self.left = left
  135.         self.right = right
  136.  
  137.     def value(self, var: Dict[str, float]) -> float:
  138.         return self.left.value(var) + self.right.value(var)
  139.  
  140.     def _proper_deriv(self, sub: str) -> Expression:
  141.         return Constant(1)
  142.  
  143.     def reduce(self):
  144.         le = self.left.reduce()
  145.         ri = self.right.reduce()
  146.         if le == Constant(0):
  147.             return ri
  148.         if ri == Constant(0):
  149.             return le
  150.         return (le + ri)._const_reduce()
  151.  
  152.     def __str__(self):
  153.         return f"({self.left} + {self.right})"
  154.  
  155.  
  156. class Mul(Expression):
  157.  
  158.     left: Expression
  159.     right: Expression
  160.  
  161.     def __init__(self, left: Expression, right: Expression):
  162.         self.subexprs = {
  163.             'left': left,
  164.             'right': right
  165.         }
  166.         self.left = left
  167.         self.right = right
  168.  
  169.     def value(self, var: Dict[str, float]) -> float:
  170.         return self.left.value(var) * self.right.value(var)
  171.  
  172.     def _proper_deriv(self, sub: str) -> Expression:
  173.         return self.right if sub == 'left' else self.left
  174.  
  175.     def reduce(self):
  176.         le = self.left.reduce()
  177.         ri = self.right.reduce()
  178.         if le == Constant(0) or ri == Constant(0):
  179.             return Constant(0)
  180.         if le == Constant(1):
  181.             return ri
  182.         if ri == Constant(1):
  183.             return le
  184.         return (le * ri)._const_reduce()
  185.  
  186.     def __str__(self):
  187.         return f"({self.left} * {self.right})"
  188.  
  189.  
  190. class Log(Expression):
  191.  
  192.     x: Expression
  193.     base: Expression
  194.  
  195.     def __init__(self, x: Expression, base: Expression = Variable('e')):
  196.         self.subexprs = {
  197.             'x': x,
  198.             'base': base
  199.         }
  200.         self.x = x
  201.         self.base = base
  202.  
  203.     def value(self, var: Dict[str, float]) -> float:
  204.         return log(self.x.value(var), self.base.value(var))
  205.  
  206.     def _proper_deriv(self, sub: str) -> Expression:
  207.         if sub == 'x':
  208.             return Constant(1) / (self.x * Log(self.base))
  209.         if sub == 'base':
  210.             return -(self / (self.base * Log(self.base)))
  211.  
  212.     def reduce(self):
  213.         b = self.base.reduce()
  214.         xi = self.x.reduce()
  215.         if isinstance(xi, Pow) and xi.base == b:
  216.             return xi.power
  217.         if b == xi:
  218.             return Constant(1)
  219.         return Log(xi, b)._const_reduce()
  220.  
  221.     def __str__(self):
  222.         return f"log_{self.base}({self.x})"
  223.  
  224.  
  225. class Pow(Expression):
  226.  
  227.     base: Expression
  228.     power: Expression
  229.  
  230.     def __init__(self, base: Expression, power: Expression):
  231.         self.subexprs = {
  232.             'base': base,
  233.             'power': power
  234.         }
  235.         self.base = base
  236.         self.power = power
  237.  
  238.     def value(self, var: Dict[str, float]) -> float:
  239.         return np.power(self.base.value(var), self.power.value(var))
  240.  
  241.     def _proper_deriv(self, sub: str) -> Expression:
  242.         if sub == 'base':
  243.             return self.power * (self.base ** (self.power - 1))
  244.         if sub == 'power':
  245.             return Log(self.base) * self
  246.  
  247.     def reduce(self):
  248.         b = self.base.reduce()
  249.         p = self.power.reduce()
  250.         if isinstance(p, Log) and p.base == b:
  251.             return p.x
  252.         return (b ** p)._const_reduce()
  253.  
  254.     def __str__(self):
  255.         return f"({self.base})^({self.power})"
  256.  
  257.  
  258. def main():
  259.     e = Log(Variable('x') * 'y' + 'x')
  260.     print(e)
  261.     print(e.value({'x': 5, 'y': 3, 'e': np.e}))
  262.     print(e.deriv('x').reduce())
  263.     pass
  264.  
  265.  
  266. if __name__ == '__main__':
  267.     main()
  268.  
RAW Paste Data