Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
from __future__ import annotations

import numbers
from abc import ABC, abstractmethod
from math import log
from typing import Dict

import numpy as np
def literal_convert(x):
    """Coerce a Python literal into an Expression node.

    - Expression instances are returned unchanged.
    - Built-in floats become ``Constant(x)`` unchanged; any other real
      number (int, bool, numpy integer/floating scalars, Fraction, ...)
      becomes ``Constant(float(x))``.
    - Strings that parse as a float become a Constant; any other string
      becomes a Variable with that name.

    Raises:
        ValueError: if ``x`` is none of the above.
    """
    if isinstance(x, Expression):
        return x
    if isinstance(x, float):
        return Constant(x)
    # numbers.Real covers int/bool as well as numpy scalars and Fraction,
    # which the previous int/float-only checks rejected.
    if isinstance(x, numbers.Real):
        return Constant(float(x))
    if isinstance(x, str):
        try:
            return Constant(float(x))
        except ValueError:
            # Not numeric text: treat it as a variable name.
            return Variable(x)
    raise ValueError(f"Can't convert ({x})@{type(x)} to expression")
class Expression(ABC):
    """Abstract node of a symbolic expression tree.

    Subclasses store their child nodes in ``subexprs``, a mapping from a
    role name (e.g. ``'left'``, ``'base'``) to the child expression, and
    implement ``value`` and ``_proper_deriv``.  Python arithmetic operators
    build new tree nodes rather than numbers; plain numbers and strings on
    either side of an operator are converted with ``literal_convert``.
    """

    # Child expressions keyed by their role in this node.
    subexprs: Dict[str, Expression]

    @abstractmethod
    def value(self, var: Dict[str, float]) -> float:
        """Evaluate the expression with variable bindings taken from ``var``."""
        raise NotImplementedError

    @abstractmethod
    def _proper_deriv(self, sub: str) -> Expression:
        """Return the partial derivative of this node w.r.t. child ``sub``."""
        raise NotImplementedError

    def deriv(self, by: str) -> Expression:
        """Differentiate with respect to the variable named ``by``.

        Applies the chain rule: sums d(child)/d(by) * d(self)/d(child) over
        every child.  The result is not simplified; call ``reduce()`` on it.
        """
        d = Constant(0)
        for name, expr in self.subexprs.items():
            d += expr.deriv(by) * self._proper_deriv(name)
        return d

    def _const_reduce(self):
        """Fold this node into a Constant if it evaluates with no variables.

        Returns ``self`` unchanged when evaluation fails: an unbound variable
        or a math domain error (ValueError), or an arithmetic failure such as
        ``math.log`` with base 1 (ZeroDivisionError, an ArithmeticError).
        """
        try:
            return Constant(self.value({}))
        except (ValueError, ArithmeticError):
            return self

    def reduce(self):
        """Return a simplified expression; by default only constant folding."""
        return self._const_reduce()

    def __eq__(self, other):
        """Structural equality: same node type and pairwise-equal children.

        Note: defining ``__eq__`` leaves instances unhashable.
        """
        return type(self) is type(other) and all(
            other.subexprs[name] == e for name, e in self.subexprs.items()
        )

    def __add__(self, other):
        return Plus(self, literal_convert(other))

    def __radd__(self, other):
        # ``other + self`` where ``other`` is a number/str (also makes sum() work).
        return Plus(literal_convert(other), self)

    def __mul__(self, other):
        return Mul(self, literal_convert(other))

    def __rmul__(self, other):
        return Mul(literal_convert(other), self)

    def __neg__(self):
        return self * -1

    def __sub__(self, other):
        # Built-in numbers are negated numerically, as before; anything else
        # (notably a str such as ``x - 'y'``, which cannot be negated) is
        # converted to an Expression first and negated symbolically.
        if not isinstance(other, (int, float)):
            other = literal_convert(other)
        return self + (-other)

    def __rsub__(self, other):
        return literal_convert(other) + (-self)

    def __pow__(self, power, modulo=None):
        if modulo is not None:
            raise NotImplementedError("modular exponentiation is not supported")
        return Pow(self, literal_convert(power))

    def __rpow__(self, other):
        return Pow(literal_convert(other), self)

    def __truediv__(self, other):
        # Same conversion rule as __sub__: ``'y' ** -1`` would raise.
        if not isinstance(other, (int, float)):
            other = literal_convert(other)
        return self * (other ** -1)

    def __rtruediv__(self, other):
        return literal_convert(other) * (self ** -1)
class Constant(Expression):
    """A fixed numeric value; a leaf of the expression tree."""

    # The stored number (literal_convert passes floats; internal code also
    # creates Constant(0) / Constant(1) with ints).
    val: float

    def __init__(self, val: float):
        self.val = val
        self.subexprs = {}  # a leaf has no children

    def value(self, var: Dict[str, float]) -> float:
        """Return the stored number; ``var`` is ignored."""
        return self.val

    def _proper_deriv(self, sub: str) -> Expression:
        # Not used: a Constant has no children and overrides deriv() itself.
        raise NotImplementedError("Constant has no sub-expressions")

    def deriv(self, var: str) -> Expression:
        """The derivative of a constant is 0."""
        return Constant(0)

    def __str__(self):
        return str(self.val)

    def __eq__(self, other):
        """Equal to another Constant holding an equal number."""
        return super().__eq__(other) and self.val == other.val
class Variable(Expression):
    """A named free variable; a leaf of the expression tree."""

    # Key looked up in the ``var`` mapping passed to value().
    name: str

    def __init__(self, name: str):
        self.name = name
        self.subexprs = {}  # a leaf has no children

    def value(self, var: Dict[str, float]) -> float:
        """Return the value bound to this variable's name in ``var``.

        Raises:
            ValueError: if ``var`` has no entry for this name.  Expression's
                constant folding catches ValueError to detect non-constant
                subtrees.
        """
        if self.name not in var:
            raise ValueError(f"Variable {self.name} not found")
        return var[self.name]

    def _proper_deriv(self, sub: str) -> Expression:
        # Not used: a Variable has no children and overrides deriv() itself.
        raise NotImplementedError("Variable has no sub-expressions")

    def deriv(self, var: str) -> Expression:
        """d(name)/d(var) is 1 when ``var`` is this variable, else 0."""
        return Constant(1 if var == self.name else 0)

    def __str__(self):
        return self.name

    def __eq__(self, other):
        """Equal to another Variable with the same name."""
        return super().__eq__(other) and self.name == other.name
class Plus(Expression):
    """Sum node representing ``left + right``."""

    left: Expression
    right: Expression

    def __init__(self, left: Expression, right: Expression):
        self.subexprs = dict(left=left, right=right)
        self.left, self.right = left, right

    def value(self, var: Dict[str, float]) -> float:
        """Evaluate both operands (left first) and add them."""
        lhs = self.left.value(var)
        rhs = self.right.value(var)
        return lhs + rhs

    def _proper_deriv(self, sub: str) -> Expression:
        # d(l + r)/dl and d(l + r)/dr are both 1.
        return Constant(1)

    def reduce(self):
        """Simplify both operands, drop a zero term, then fold constants."""
        lhs = self.left.reduce()
        rhs = self.right.reduce()
        zero = Constant(0)
        if lhs == zero:
            return rhs
        if rhs == zero:
            return lhs
        return Plus(lhs, rhs)._const_reduce()

    def __str__(self):
        return f"({self.left} + {self.right})"
class Mul(Expression):
    """Product node representing ``left * right``."""

    left: Expression
    right: Expression

    def __init__(self, left: Expression, right: Expression):
        self.subexprs = dict(left=left, right=right)
        self.left, self.right = left, right

    def value(self, var: Dict[str, float]) -> float:
        """Evaluate both operands (left first) and multiply them."""
        lhs = self.left.value(var)
        rhs = self.right.value(var)
        return lhs * rhs

    def _proper_deriv(self, sub: str) -> Expression:
        # d(l*r)/dl = r; for any other key (i.e. 'right') the result is l.
        if sub == 'left':
            return self.right
        return self.left

    def reduce(self):
        """Simplify operands; x*0 -> 0, x*1 -> x, then fold constants."""
        lhs = self.left.reduce()
        rhs = self.right.reduce()
        zero = Constant(0)
        one = Constant(1)
        if lhs == zero or rhs == zero:
            return zero
        if lhs == one:
            return rhs
        if rhs == one:
            return lhs
        return Mul(lhs, rhs)._const_reduce()

    def __str__(self):
        return f"({self.left} * {self.right})"
class Log(Expression):
    """Logarithm node: log of ``x`` to base ``base``.

    The default base is the *variable* ``e`` rather than a numeric constant,
    so evaluating a default-base Log needs ``'e'`` bound in the variable map.
    """

    x: Expression
    base: Expression

    def __init__(self, x: Expression, base: Expression = Variable('e')):
        self.x = x
        self.base = base
        self.subexprs = dict(x=x, base=base)

    def value(self, var: Dict[str, float]) -> float:
        """Evaluate the argument, then the base, and return math.log(arg, base)."""
        arg = self.x.value(var)
        b = self.base.value(var)
        return log(arg, b)

    def _proper_deriv(self, sub: str) -> Expression:
        # d/dx log_b(x) = 1 / (x * ln b)
        # d/db log_b(x) = -log_b(x) / (b * ln b)
        # ``Log(self.base)`` uses the default base variable ``e`` (ln b).
        ln_base = Log(self.base)
        if sub == 'x':
            return Constant(1) / (self.x * ln_base)
        if sub == 'base':
            return -(self / (self.base * ln_base))
        return None

    def reduce(self):
        """Simplify: log_b(b ** p) -> p, log_b(b) -> 1, else fold constants."""
        base = self.base.reduce()
        arg = self.x.reduce()
        if isinstance(arg, Pow) and arg.base == base:
            return arg.power
        if base == arg:
            return Constant(1)
        return Log(arg, base)._const_reduce()

    def __str__(self):
        return f"log_{self.base}({self.x})"
class Pow(Expression):
    """Power node representing ``base ** power``."""

    base: Expression
    power: Expression

    def __init__(self, base: Expression, power: Expression):
        self.subexprs = {
            'base': base,
            'power': power,
        }
        self.base = base
        self.power = power

    def value(self, var: Dict[str, float]) -> float:
        """Evaluate ``base ** power`` in floating point.

        ``np.float_power`` is used instead of ``np.power``: with two integer
        operands ``np.power`` overflows silently on large results and raises
        ValueError for negative exponents, while ``float_power`` always
        computes in (at least) float64.  Array inputs still broadcast.
        """
        return np.float_power(self.base.value(var), self.power.value(var))

    def _proper_deriv(self, sub: str) -> Expression:
        # d/db b^p = p * b^(p - 1);  d/dp b^p = ln(b) * b^p
        # ``Log(self.base)`` uses Log's default base (the variable ``e``).
        if sub == 'base':
            return self.power * (self.base ** (self.power - 1))
        if sub == 'power':
            return Log(self.base) * self

    def reduce(self):
        """Simplify: b ** log_b(x) -> x, b ** 1 -> b, b ** 0 -> 1, else fold."""
        b = self.base.reduce()
        p = self.power.reduce()
        if isinstance(p, Log) and p.base == b:
            return p.x
        # Identities valid for every base (np.float_power(0, 0) is also 1.0);
        # they tidy derivatives such as d(x^2)/dx, which produce x^(2 + -1).
        if p == Constant(1):
            return b
        if p == Constant(0):
            return Constant(1)
        return Pow(b, p)._const_reduce()

    def __str__(self):
        return f"({self.base})^({self.power})"
def main():
    """Demo: build log(x*y + x) (base = variable e), evaluate it, and
    print its simplified derivative with respect to x."""
    expr = Log(Variable('x') * 'y' + 'x')
    print(expr)
    bindings = {'x': 5, 'y': 3, 'e': np.e}
    print(expr.value(bindings))
    derivative = expr.deriv('x')
    print(derivative.reduce())


if __name__ == '__main__':
    main()
Advertisement
Add Comment
Please, Sign In to add comment