Advertisement
theosib

C++ derivative library

Mar 27th, 2022
1,441
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 23.95 KB | None | 0 0
  1. #include <memory>
  2. #include <iostream>
  3. #include <string>
  4. #include <unordered_map>
  5. #include <map>
  6. #include <math.h>
  7. #include <vector>
  8. #include <initializer_list>
  9.  
  10. struct NodeObj;
  11. typedef std::shared_ptr<NodeObj> Node;
  12.  
  13. Node NodeNull(0);
  14.  
  15. class NodeObj {
  16. protected:
  17.     NodeObj() {}
  18. public:
  19.     virtual Node unwrap() const;
  20.     virtual bool isNeg() const;
  21.     virtual bool isRecip() const;
  22.     virtual Node negate() const;
  23.     virtual Node reciprocate() const;
  24.     virtual bool isConst() const;
  25.     virtual double getConst() const;
  26.     virtual bool isLog() const;
  27.     virtual bool isExp() const;
  28.     virtual bool isPower() const;
  29.     virtual bool isSum() const;
  30.     virtual bool isDiff() const;
  31. public:
  32.     virtual Node derivative(const Node& wrt) = 0;
  33.     virtual void print(std::ostream& os) const = 0;
  34.     virtual ~NodeObj() {}
  35. };
  36.  
  37. class Const : public NodeObj {
  38. private:
  39.     double value;
  40.     static std::map<double, Node> intern;
  41. protected:
  42.     Const(double v) : value(v) {}
  43. public:
  44.     virtual Node derivative(const Node& wrt);
  45.     static Node make(double v);
  46.     virtual void print(std::ostream& os) const;
  47.     virtual Node negate() const;
  48.     virtual Node reciprocate() const;
  49.     virtual bool isConst() const;
  50.     virtual double getConst() const;
  51. };
  52.  
  53. inline
  54. Node number(double n) {
  55.     return Const::make(n);
  56. }
  57.  
  58. class Variable : public NodeObj {
  59. private:
  60.     std::string name;
  61.     static std::unordered_map<std::string, Node> intern;
  62. protected:
  63.     Variable(const std::string& n) : name(n) {}
  64. public:
  65.     virtual Node derivative(const Node& wrt);
  66.     static Node make(const std::string& n);
  67.     virtual void print(std::ostream& os) const;
  68. };
  69.  
  70. inline
  71. Node var(const std::string& n) {
  72.     return Variable::make(n);
  73. }
  74.  
  75. class Diff;
  76. class Sum : public NodeObj {
  77. private:
  78.     std::vector<Node> a;
  79.     Node d;
  80. protected:
  81.     Sum() {}
  82.     Sum(const Node& p, const Node& q) {
  83.         a.push_back(p);
  84.         a.push_back(q);
  85.     }
  86.     virtual bool isSum() const;
  87. public:
  88.     virtual Node derivative(const Node& wrt);
  89.     static Node make(const Node& p, const Node& q);
  90.     static Node make(std::initializer_list<Node> l);
  91.     static Node make(const std::vector<Node>& l);
  92.     template<typename It> static Node make(It start, It end);
  93.     virtual void print(std::ostream& os) const;
  94.    
  95.     friend class Diff;
  96. };
  97.  
  98. inline
  99. Node& operator+=(Node& a1, const Node& a2) {
  100.     Node n = Sum::make(a1, a2);
  101.     a1.swap(n);
  102.     return a1;
  103. }
  104.  
  105. inline
  106. Node operator +(const Node& p, const Node& q) {
  107.     return Sum::make(p, q);
  108. }
  109.  
  110. inline
  111. Node sum(std::initializer_list<Node> l)
  112. {
  113.     return Sum::make(l);
  114. }
  115.  
  116. inline
  117. Node sum(const std::vector<Node>& l)
  118. {
  119.     return Sum::make(l);
  120. }
  121.  
  122.  
  123. class Neg : public NodeObj {
  124. private:
  125.     Node a;
  126.     Node d;
  127. protected:
  128.     Neg(const Node& p) : a(p) {}
  129.     virtual Node unwrap() const;
  130.     virtual bool isNeg() const;
  131. public:
  132.     virtual Node derivative(const Node& wrt);
  133.     static Node make(const Node& p);
  134.     virtual void print(std::ostream& os) const;
  135.     virtual Node negate() const;
  136. };
  137.  
  138. inline
  139. Node operator -(const Node& p) {
  140.     return Neg::make(p);
  141. }
  142.  
  143. class Log : public NodeObj {
  144. private:
  145.     Node a;
  146.     Node d;
  147. protected:
  148.     Log(const Node& p) : a(p) {}
  149.     virtual Node unwrap() const;
  150.     virtual bool isLog() const;
  151. public:
  152.     virtual Node derivative(const Node& wrt);
  153.     static Node make(const Node& p);
  154.     virtual void print(std::ostream& os) const;
  155. };
  156.  
  157. inline
  158. Node log(const Node& p) {
  159.     return Log::make(p);
  160. }
  161.  
  162. class Exp : public NodeObj {
  163. private:
  164.     Node a;
  165.     Node d;
  166. protected:
  167.     Exp(const Node& p) : a(p) {}
  168.     virtual Node unwrap() const;
  169.     virtual bool isExp() const;
  170. public:
  171.     virtual Node derivative(const Node& wrt);
  172.     static Node make(const Node& p);
  173.     virtual void print(std::ostream& os) const;
  174. };
  175.  
  176. inline
  177. Node exp(const Node& p) {
  178.     return Exp::make(p);
  179. }
  180.  
  181. class Sin : public NodeObj {
  182. private:
  183.     Node a;
  184.     Node d;
  185. protected:
  186.     Sin(const Node& p) : a(p) {}
  187.     virtual Node unwrap() const;
  188. public:
  189.     virtual Node derivative(const Node& wrt);
  190.     static Node make(const Node& p);
  191.     virtual void print(std::ostream& os) const;
  192. };
  193.  
  194. inline
  195. Node sin(const Node& p) {
  196.     return Sin::make(p);
  197. }
  198.  
  199. class Cos : public NodeObj {
  200. private:
  201.     Node a;
  202.     Node d;
  203. protected:
  204.     Cos(const Node& p) : a(p) {}
  205.     virtual Node unwrap() const;
  206. public:
  207.     virtual Node derivative(const Node& wrt);
  208.     static Node make(const Node& p);
  209.     virtual void print(std::ostream& os) const;
  210. };
  211.  
  212. inline
  213. Node cos(const Node& p) {
  214.     return Cos::make(p);
  215. }
  216.  
  217. class Tan : public NodeObj {
  218. private:
  219.     Node a;
  220.     Node d;
  221. protected:
  222.     Tan(const Node& p) : a(p) {}
  223.     virtual Node unwrap() const;
  224. public:
  225.     virtual Node derivative(const Node& wrt);
  226.     static Node make(const Node& p);
  227.     virtual void print(std::ostream& os) const;
  228. };
  229.  
  230. inline
  231. Node tan(const Node& p) {
  232.     return Tan::make(p);
  233. }
  234.  
  235. class Reciprocal : public NodeObj {
  236. private:
  237.     Node a;
  238.     Node d;
  239. protected:
  240.     Reciprocal(const Node& p) : a(p) {}
  241.     virtual Node unwrap() const;
  242.     virtual bool isRecip() const;
  243.     virtual Node reciprocate() const;
  244. public:
  245.     virtual Node derivative(const Node& wrt);
  246.     static Node make(const Node& p);
  247.     virtual void print(std::ostream& os) const;
  248. };
  249.  
  250. inline
  251. Node recip(const Node& p) {
  252.     return Reciprocal::make(p);
  253. }
  254.  
  255. class Diff : public NodeObj {
  256. private:
  257.     Node a, b;
  258.     Node d;
  259. protected:
  260.     Diff(const Node& p, const Node& q) : a(p), b(q) {}
  261.     bool isDiff() const;
  262. public:
  263.     virtual Node derivative(const Node& wrt);
  264.     static Node make(const Node& p, const Node& q);
  265.     virtual void print(std::ostream& os) const;
  266.     virtual Node negate() const;
  267.    
  268.     friend class Sum;
  269. };
  270.  
  271. inline
  272. Node operator -(const Node& p, const Node& q) {
  273.     return Diff::make(p, q);
  274. }
  275.  
  276. inline
  277. Node& operator-=(Node& a1, const Node& a2) {
  278.     Node n = Diff::make(a1, a2);
  279.     a1.swap(n);
  280.     return a1;
  281. }
  282.  
  283. class Product : public NodeObj {
  284. private:
  285.     Node a, b;
  286.     Node d;
  287. protected:
  288.     Product(const Node& p, const Node& q) : a(p), b(q) {}
  289. public:
  290.     virtual Node derivative(const Node& wrt);
  291.     static Node make(const Node& p, const Node& q);
  292.     virtual void print(std::ostream& os) const;
  293. };
  294.  
  295. inline
  296. Node operator *(const Node& p, const Node& q) {
  297.     return Product::make(p, q);
  298. }
  299.  
  300. inline
  301. Node& operator*=(Node& a1, const Node& a2) {
  302.     Node n = Product::make(a1, a2);
  303.     a1.swap(n);
  304.     return a1;
  305. }
  306.  
  307. class Power : public NodeObj {
  308. private:
  309.     Node a, b;
  310.     Node d;
  311. protected:
  312.     Power(const Node& p, const Node& q) : a(p), b(q) {}
  313.     virtual Node reciprocate() const;
  314.     virtual bool isPower() const;
  315. public:
  316.     virtual Node derivative(const Node& wrt);
  317.     static Node make(const Node& p, const Node& q);
  318.     virtual void print(std::ostream& os) const;
  319. };
  320.  
  321. inline
  322. Node pow(const Node& p, const Node& q) {
  323.     return Power::make(p, q);
  324. }
  325.  
  326. class Quotient : public NodeObj {
  327. private:
  328.     Node a, b;
  329.     Node d;
  330. protected:
  331.     virtual Node reciprocate() const;
  332.     Quotient(const Node& p, const Node& q) : a(p), b(q) {}
  333. public:
  334.     virtual Node derivative(const Node& wrt);
  335.     static Node make(const Node& p, const Node& q);
  336.     virtual void print(std::ostream& os) const;
  337. };
  338.  
  339. inline
  340. Node operator /(const Node& p, const Node& q) {
  341.     return Quotient::make(p, q);
  342. }
  343.  
  344. inline
  345. Node& operator/=(Node& a1, const Node& a2) {
  346.     Node n = Quotient::make(a1, a2);
  347.     a1.swap(n);
  348.     return a1;
  349. }
  350.  
  351.  
  352.  
  353.  
  354.  
  355. bool isPowerOfTwo(double n)
  356. {
  357.    if (n==0) return false;
  358.    return (ceil(log2(n)) == floor(log2(n)));
  359. }
  360.  
  361. /* Node */
  362.  
  363. std::ostream& operator<<(std::ostream& os, const NodeObj& n)
  364. {
  365.     n.print(os);
  366.     return os;
  367. }
  368.  
  369. Node NodeObj::unwrap() const
  370. {
  371.     return NodeNull;
  372. }
  373.  
  374. bool NodeObj::isNeg() const
  375. {
  376.     return false;
  377. }
  378.  
  379. bool NodeObj::isRecip() const
  380. {
  381.     return false;
  382. }
  383.  
  384. Node NodeObj::negate() const
  385. {
  386.     return NodeNull;
  387. }
  388.  
  389. Node NodeObj::reciprocate() const
  390. {
  391.     return NodeNull;
  392. }
  393.  
  394. bool NodeObj::isConst() const
  395. {
  396.     return false;
  397. }
  398.  
  399. double NodeObj::getConst() const
  400. {
  401.     return 0;
  402. }
  403.  
  404. bool NodeObj::isLog() const
  405. {
  406.     return false;
  407. }
  408.  
  409. bool NodeObj::isExp() const
  410. {
  411.     return false;
  412. }
  413.  
  414. bool NodeObj::isPower() const
  415. {
  416.     return false;
  417. }
  418.  
  419. bool NodeObj::isSum() const
  420. {
  421.     return false;
  422. }
  423.  
  424. bool NodeObj::isDiff() const
  425. {
  426.     return false;
  427. }
  428.  
  429. /* Const */
  430.  
  431. std::map<double, Node> Const::intern;
  432.  
  433. Node Const::make(double v)
  434. {
  435.     std::map<double, Node>::iterator ii = intern.find(v);
  436.     if (ii != intern.end()) return ii->second;
  437.     Node p(new Const(v));
  438.     intern[v] = p;
  439.     return p;
  440. }
  441.  
  442. Node Const::negate() const
  443. {
  444.     return Const::make(-value);
  445. }
  446.  
  447. void Const::print(std::ostream& os) const
  448. {
  449.     os << value;
  450. }
  451.  
  452. Node ConstZero = Const::make(0);
  453. Node ConstOne  = Const::make(1);
  454. Node ConstTwo  = Const::make(2);
  455. Node ConstNeg  = Const::make(-1);
  456. Node ConstInf  = Const::make(1.0 / 0.0);
  457.  
  458. Node Const::reciprocate() const
  459. {
  460.     if (value == 0) return ConstInf;
  461.     if (value == 1) return ConstOne;
  462.     if (isPowerOfTwo(value)) return Const::make(1.0 / value);
  463.     if (isinf(value)) return ConstZero;
  464.     return NodeNull;
  465. }
  466.  
  467. bool Const::isConst() const
  468. {
  469.     return true;
  470. }
  471.  
  472. double Const::getConst() const
  473. {
  474.     return value;
  475. }
  476.  
  477. Node Const::derivative(const Node& wrt)
  478. {
  479.     return ConstZero;
  480. }
  481.  
  482.  
  483.  
  484. /* Variable */
  485.  
  486. std::unordered_map<std::string, Node> Variable::intern;
  487.  
  488. Node Variable::make(const std::string& n)
  489. {
  490.     std::unordered_map<std::string, Node>::iterator ii = intern.find(n);
  491.     if (ii != intern.end()) return ii->second;
  492.     Node p(new Variable(n));
  493.     intern[n] = p;
  494.     return p;
  495. }
  496.  
  497. void Variable::print(std::ostream& os) const
  498. {
  499.     os << name;
  500. }
  501.  
  502. Node Variable::derivative(const Node& wrt)
  503. {
  504.     if (wrt.get() == this) return ConstOne;
  505.     return ConstZero;
  506. }
  507.  
  508.  
  509. /* Sum */
  510.  
  511. Node Sum::make(const Node& p, const Node& q)
  512. {
  513.     if (p->isConst() && q->isConst()) {
  514.         return Const::make(p->getConst() + q->getConst());
  515.     }
  516.     if (p == ConstZero) return q;
  517.     if (q == ConstZero) return p;
  518.     if (p == q) return Product::make(ConstTwo, p);
  519.    
  520.     if (p->isSum() || p->isDiff() || q->isSum() || q->isDiff()) {
  521.         std::vector<Node> addends;
  522.         if (p->isSum()) {
  523.             Sum *s = static_cast<Sum*>(p.get());
  524.             addends.insert(addends.end(), s->a.begin(), s->a.end());
  525.         } else if (p->isDiff()) {
  526.             Diff *s = static_cast<Diff*>(p.get());
  527.             addends.push_back(s->a);
  528.             addends.push_back(Neg::make(s->b));
  529.         } else {
  530.             addends.push_back(p);
  531.         }
  532.         if (q->isSum()) {
  533.             Sum *s = static_cast<Sum*>(q.get());
  534.             addends.insert(addends.end(), s->a.begin(), s->a.end());
  535.         } else if (q->isDiff()) {
  536.             Diff *s = static_cast<Diff*>(q.get());
  537.             addends.push_back(s->a);
  538.             addends.push_back(Neg::make(s->b));
  539.         } else {
  540.             addends.push_back(q);
  541.         }
  542.         return Sum::make(addends);
  543.     }
  544.    
  545.     if (p->isNeg() && q->isNeg()) return Neg::make(Sum::make(p->unwrap(), q->unwrap()));
  546.     if (p->isNeg()) return Diff::make(q, p->unwrap());
  547.     if (q->isNeg()) return Diff::make(p, q->unwrap());
  548.    
  549.     return Node(new Sum(p, q));
  550. }
  551.  
  552. template<typename It>
  553. Node Sum::make(It first, It last)
  554. {    
  555.     double const_total = 0;
  556.     Sum *s = new Sum();
  557.     for (auto it = first; it != last; ++it) {
  558.         const Node& n(*it);
  559.         if(n->isConst()) {
  560.             const_total += n->getConst();
  561.         } else {
  562.             s->a.push_back(n);
  563.         }
  564.     }
  565.     if (const_total != 0) {
  566.         s->a.push_back(Const::make(const_total));
  567.     }
  568.    
  569.     if (s->a.size() == 0) {
  570.         delete s;
  571.         return ConstZero;
  572.     } else if (s->a.size() == 1) {
  573.         Node n(s->a[0]);
  574.         delete s;
  575.         return n;
  576.     } else if (s->a.size() == 2) {
  577.         Node n(Sum::make(s->a[0], s->a[1]));
  578.         delete s;
  579.         return n;
  580.     } else {
  581.         return Node(s);
  582.     }
  583. }
  584.  
  585. Node Sum::make(std::initializer_list<Node> l)
  586. {
  587.     return Sum::make(l.begin(), l.end());
  588. }
  589.  
  590. Node Sum::make(const std::vector<Node>& l)
  591. {
  592.     return Sum::make(l.begin(), l.end());
  593. }
  594.  
  595. bool Sum::isSum() const
  596. {
  597.     return true;
  598. }
  599.  
  600. void Sum::print(std::ostream& os) const
  601. {
  602.     bool first = true;
  603.     os << '(';
  604.     for (const Node& n : a) {
  605.         if (!first) os << " + ";
  606.         first = false;
  607.         os << (*n);
  608.     }
  609.     os << ')';
  610.     // os << '(' << (*a) << " + " << (*b) << ')';
  611. }
  612.  
  613. Node Sum::derivative(const Node& wrt)
  614. {
  615.     if (d) return d;
  616.     std::vector<Node> v;
  617.     for (const Node& n : a) {
  618.         v.push_back(n->derivative(wrt));
  619.     }
  620.     return Sum::make(v);
  621.     // Node p = a->derivative(wrt);
  622.     // Node q = b->derivative(wrt);
  623.     // return (d = Sum::make(p, q));
  624. }
  625.  
  626.  
  627. /* Diff */
  628.  
  629. Node Diff::make(const Node& p, const Node& q)
  630. {
  631.     if (p->isConst() && q->isConst()) {
  632.         return Const::make(p->getConst() - q->getConst());
  633.     }
  634.     if (p == ConstZero) return Neg::make(q);
  635.     if (q == ConstZero) return p;
  636.     if (p == q) return ConstZero;
  637.    
  638.     if (p->isSum() || p->isDiff() || q->isSum() || q->isDiff()) {
  639.         std::vector<Node> addends;
  640.         if (p->isSum()) {
  641.             Sum *s = static_cast<Sum*>(p.get());
  642.             addends.insert(addends.end(), s->a.begin(), s->a.end());
  643.         } else if (p->isDiff()) {
  644.             Diff *s = static_cast<Diff*>(p.get());
  645.             addends.push_back(s->a);
  646.             addends.push_back(Neg::make(s->b));
  647.         } else {
  648.             addends.push_back(p);
  649.         }
  650.         if (q->isSum()) {
  651.             Sum *s = static_cast<Sum*>(q.get());
  652.             for (const Node& n : s->a) {
  653.                 addends.push_back(Neg::make(n));
  654.             }
  655.         } else if (q->isDiff()) {
  656.             Diff *s = static_cast<Diff*>(q.get());
  657.             addends.push_back(Neg::make(s->a));
  658.             addends.push_back(s->b);
  659.         } else {
  660.             addends.push_back(Neg::make(q));
  661.         }
  662.         return Sum::make(addends);
  663.     }
  664.        
  665.     if (p->isNeg() && q->isNeg()) return Diff::make(q->unwrap(), p->unwrap());
  666.     if (q->isNeg()) return Sum::make(p, q->unwrap());
  667.     if (p->isNeg()) return Neg::make(Sum::make(p->unwrap(), q));
  668.    
  669.     return Node(new Diff(p, q));
  670. }
  671.  
  672. bool Diff::isDiff() const
  673. {
  674.     return true;
  675. }
  676.  
  677. Node Diff::negate() const
  678. {
  679.     return Diff::make(b, a);
  680. }
  681.  
  682. void Diff::print(std::ostream& os) const
  683. {
  684.     os << '(' << (*a) << " - " << (*b) << ')';
  685. }
  686.  
  687. Node Diff::derivative(const Node& wrt)
  688. {
  689.     if (d) return d;
  690.     Node p = a->derivative(wrt);
  691.     Node q = b->derivative(wrt);
  692.     return (d = Diff::make(p, q));
  693. }
  694.  
  695.  
  696. /* Neg */
  697.  
  698. Node Neg::make(const Node& p)
  699. {
  700.     Node neg = p->negate();
  701.     if (neg) return neg;
  702.     return Node(new Neg(p));
  703. }
  704.  
  705. void Neg::print(std::ostream& os) const
  706. {
  707.     os << '-' << (*a);
  708. }
  709.  
  710. Node Neg::derivative(const Node& wrt)
  711. {
  712.     if (d) return d;
  713.     Node p = a->derivative(wrt);
  714.     return (d = Neg::make(p));
  715. }
  716.  
  717. Node Neg::unwrap() const
  718. {
  719.     return a;
  720. }
  721.  
  722. Node Neg::negate() const
  723. {
  724.     return a;
  725. }
  726.  
  727. bool Neg::isNeg() const
  728. {
  729.     return true;
  730. }
  731.  
  732.  
  733. /* Log */
  734.  
  735. Node Log::make(const Node& p)
  736. {
  737.     if (p->isExp()) return p->unwrap();
  738.     /// XXX handle constant
  739.     return Node(new Log(p));
  740. }
  741.  
  742. void Log::print(std::ostream& os) const
  743. {
  744.     os << "log" << (*a);
  745. }
  746.  
  747. Node Log::derivative(const Node& wrt)
  748. {
  749.     if (d) return d;
  750.     Node p = a->derivative(wrt);
  751.     return (d = Quotient::make(p, a));
  752. }
  753.  
  754. Node Log::unwrap() const
  755. {
  756.     return a;
  757. }
  758.  
  759. bool Log::isLog() const
  760. {
  761.     return true;
  762. }
  763.  
  764.  
  765. /* Exp */
  766.  
  767. Node Exp::make(const Node& p)
  768. {
  769.     if (p->isLog()) return p->unwrap();
  770.     /// XXX handle constant
  771.     return Node(new Exp(p));
  772. }
  773.  
  774. void Exp::print(std::ostream& os) const
  775. {
  776.     os << "exp" << (*a);
  777. }
  778.  
  779. Node Exp::derivative(const Node& wrt)
  780. {
  781.     if (d) return d;
  782.     Node p = a->derivative(wrt);
  783.     Node e = Exp::make(a);
  784.     return (d = Product::make(p, e));
  785. }
  786.  
  787. Node Exp::unwrap() const
  788. {
  789.     return a;
  790. }
  791.  
  792. bool Exp::isExp() const
  793. {
  794.     return true;
  795. }
  796.  
  797.  
  798. /* Reciprocal */
  799.  
  800. Node Reciprocal::make(const Node& p)
  801. {
  802.     Node rec = p->reciprocate();
  803.     if (rec) return rec;
  804.     if (p->isRecip()) return p->unwrap();
  805.     return Node(new Reciprocal(p));
  806. }
  807.  
  808. void Reciprocal::print(std::ostream& os) const
  809. {
  810.     os << "1/" << (*a);
  811. }
  812.  
  813. Node Reciprocal::derivative(const Node& wrt)
  814. {
  815.     if (d) return d;
  816.     Node ap = a->derivative(wrt);
  817.     return (d = Quotient::make(ap, Product::make(ConstTwo, a)));
  818. }
  819.  
  820. Node Reciprocal::unwrap() const
  821. {
  822.     return a;
  823. }
  824.  
  825. Node Reciprocal::reciprocate() const
  826. {
  827.     return a;
  828. }
  829.  
  830. bool Reciprocal::isRecip() const
  831. {
  832.     return true;
  833. }
  834.  
  835.  
  836. /* Product */
  837.  
  838. Node Product::make(const Node& p, const Node& q)
  839. {
  840.     if (p->isConst() && q->isConst()) {
  841.         return Const::make(p->getConst() * q->getConst());
  842.     }
  843.     if (p == ConstZero) return ConstZero;
  844.     if (q == ConstZero) return ConstZero;
  845.     if (p == ConstOne) return q;
  846.     if (q == ConstOne) return p;
  847.     if (p == ConstNeg) return Neg::make(q);
  848.     if (q == ConstNeg) return Neg::make(p);
  849.    
  850.     if (p->isRecip() && q->isRecip()) return Reciprocal::make(Product::make(p->unwrap(), q->unwrap()));
  851.     if (p->isRecip()) return Quotient::make(q, p);
  852.     if (q->isRecip()) return Quotient::make(p, q);
  853.        
  854.     return Node(new Product(p, q));
  855. }
  856.  
  857. void Product::print(std::ostream& os) const
  858. {
  859.     os << '(' << (*a) << " * " << (*b) << ')';
  860. }
  861.  
  862. Node Product::derivative(const Node& wrt)
  863. {
  864.     if (d) return d;
  865.     Node ap = a->derivative(wrt);
  866.     Node bp = b->derivative(wrt);
  867.     Node r = Product::make(a, bp);
  868.     Node s = Product::make(b, ap);
  869.     return (d = Sum::make(r, s));
  870. }
  871.  
  872.  
  873. /* Quotient */
  874.  
  875. Node Quotient::make(const Node& p, const Node& q)
  876. {
  877.     // if (p->isConst() && q->isConst()) {
  878.     //     return Const::make(p->getConst() / q->getConst());
  879.     // }
  880.     if (p == ConstZero) return ConstZero;
  881.     if (q == ConstZero) return ConstInf;
  882.     if (p == ConstOne) return Reciprocal::make(q);
  883.     if (q == ConstOne) return p;
  884.     if (p == ConstNeg) return Neg::make(Reciprocal::make(q));
  885.     if (q == ConstNeg) return Neg::make(p);
  886.    
  887.     if (p->isRecip() && q->isRecip()) return Quotient::make(q->unwrap(), p->unwrap());
  888.     if (p->isRecip()) return Reciprocal::make(Product::make(p->unwrap(), q));
  889.     if (q->isRecip()) return Product::make(p, q->unwrap());
  890.    
  891.     // Negate power
  892.    
  893.     return Node(new Quotient(p, q));
  894. }
  895.  
  896. void Quotient::print(std::ostream& os) const
  897. {
  898.     os << '(' << (*a) << " / " << (*b) << ')';
  899. }
  900.  
  901. Node Quotient::derivative(const Node& wrt)
  902. {
  903.     if (d) return d;
  904.     Node ap = a->derivative(wrt);
  905.     Node bp = b->derivative(wrt);
  906.     Node r = Product::make(b, ap);
  907.     Node s = Product::make(a, bp);
  908.     Node dif = Diff::make(r, s);
  909.     Node sqr = Power::make(b, ConstTwo);
  910.     return (d = Quotient::make(dif, sqr));
  911. }
  912.  
  913. Node Quotient::reciprocate() const
  914. {
  915.     return Quotient::make(b, a);
  916. }
  917.  
  918.  
  919. /* Power */
  920.  
  921. Node Power::make(const Node& p, const Node& q)
  922. {
  923.     // if (p->isConst() && q->isConst()) {
  924.     //     return Const::make(pow(p->getConst(), q->getConst()));
  925.     // }
  926.     if (p == ConstZero) return ConstZero;
  927.     if (q == ConstZero) return ConstOne;
  928.     if (p == ConstOne) return ConstOne;
  929.     if (q == ConstOne) return p;
  930.     if (q == ConstNeg) return Reciprocal::make(p);
  931.    
  932.     // Multiply powers
  933.     if (p->isPower()) {
  934.         Power *pw = static_cast<Power*>(p.get());
  935.         return Power::make(pw->a, Product::make(pw->b, q));
  936.     }
  937.    
  938.     if (p->isRecip()) return Reciprocal::make(Power::make(p->unwrap(), q));
  939.        
  940.     return Node(new Power(p, q));
  941. }
  942.  
  943. Node Power::reciprocate() const
  944. {
  945.     return Power::make(a, Neg::make(b));
  946. }
  947.  
  948. void Power::print(std::ostream& os) const
  949. {
  950.     os << '(' << (*a) << " ^ " << (*b) << ')';
  951. }
  952.  
  953. Node Power::derivative(const Node& wrt)
  954. {
  955.     if (d) return d;
  956.  
  957.     // g(x) f(x)^(g(x) - 1) f'(x) + f(x)^g(x) log(f(x)) g'(x)
  958.    
  959.     Node ap = a->derivative(wrt);
  960.     Node bp = b->derivative(wrt);
  961.     Node bm1 = Diff::make(b, ConstOne);
  962.     Node lg = Log::make(a);
  963.    
  964.     Node p1 = Power::make(a, bm1);
  965.     Node n1 = Product::make(Product::make(b, p1), ap);
  966.    
  967.     Node p2 = Power::make(a, b);
  968.     Node n2 = Product::make(Product::make(p2, lg), bp);
  969.    
  970.     return (d = Sum::make(n1, n2));
  971. }
  972.  
  973. bool Power::isPower() const
  974. {
  975.     return true;
  976. }
  977.  
  978.  
  979. /* Sin */
  980.  
  981. Node Sin::make(const Node& p)
  982. {
  983.     return Node(new Sin(p));
  984. }
  985.  
  986. void Sin::print(std::ostream& os) const
  987. {
  988.     os << "sin" << (*a);
  989. }
  990.  
  991. Node Sin::derivative(const Node& wrt)
  992. {
  993.     if (d) return d;
  994.     Node p = a->derivative(wrt);
  995.     Node c = Cos::make(a);
  996.     return (d = Product::make(c, p));
  997. }
  998.  
  999. Node Sin::unwrap() const
  1000. {
  1001.     return a;
  1002. }
  1003.  
  1004.  
  1005. /* Cos */
  1006.  
  1007. Node Cos::make(const Node& p)
  1008. {
  1009.     return Node(new Cos(p));
  1010. }
  1011.  
  1012. void Cos::print(std::ostream& os) const
  1013. {
  1014.     os << "cos" << (*a);
  1015. }
  1016.  
  1017. Node Cos::derivative(const Node& wrt)
  1018. {
  1019.     if (d) return d;
  1020.     Node p = a->derivative(wrt);
  1021.     Node c = Sin::make(a);
  1022.     return (d = Product::make(c, p));
  1023. }
  1024.  
  1025. Node Cos::unwrap() const
  1026. {
  1027.     return a;
  1028. }
  1029.  
  1030.  
  1031. /* Tan */
  1032.  
  1033. Node Tan::make(const Node& p)
  1034. {
  1035.     return Node(new Tan(p));
  1036. }
  1037.  
  1038. void Tan::print(std::ostream& os) const
  1039. {
  1040.     os << "tan" << (*a);
  1041. }
  1042.  
  1043. Node Tan::derivative(const Node& wrt)
  1044. {
  1045.     if (d) return d;
  1046.     Node p = a->derivative(wrt);
  1047.     Node c = Power::make(Cos::make(a), Const::make(-2));
  1048.     return (d = Product::make(p, c));
  1049. }
  1050.  
  1051. Node Tan::unwrap() const
  1052. {
  1053.     return a;
  1054. }
  1055.  
  1056.  
  1057. int main()
  1058. {
  1059.     // Node n = Product::make(Variable::make("x"), ConstTwo);
  1060.     // Node n = Diff::make(Variable::make("x"), Neg::make(Variable::make("x")));
  1061.     // Node n = Product::make(Reciprocal::make(Variable::make("x")), Reciprocal::make(Variable::make("y")));
  1062.     // Node n = Power::make(Reciprocal::make(Const::make(3)), Reciprocal::make(Const::make(5)));
  1063.     // Node n = Power::make(Variable::make("x"), Variable::make("y"));
  1064.     // Node m = Power::make(n, Variable::make("z"));
  1065.     //Node n = Sin::make(Power::make(Variable::make("x"), ConstTwo));
  1066.     // Node n = Product::make(Reciprocal::make(Const::make(3)), Const::make(3));
  1067.     // std::vector<Node> terms;
  1068.     // Node x = Variable::make("x");
  1069.     // terms.push_back(Power::make(x, Const::make(4)));
  1070.     // terms.push_back(Power::make(x, Const::make(3)));
  1071.     // terms.push_back(Power::make(x, Const::make(2)));
  1072.     // terms.push_back(x);
  1073.     // terms.push_back(Const::make(10));
  1074.     // terms.push_back(Sin::make(Power::make(x, Const::make(5))));
  1075.     // Node n = Sum::make(terms);
  1076.    
  1077.     // std::vector<Node> terms;
  1078.     // Node x = var("x");
  1079.     // terms.push_back(pow(x, number(4)));
  1080.     // terms.push_back(pow(x, number(3)));
  1081.     // terms.push_back(pow(x, number(2)));
  1082.     // terms.push_back(x);
  1083.     // terms.push_back(number(10));
  1084.     // terms.push_back(sin(pow(x, number(5))));
  1085.     // Node n = sum(terms);
  1086.    
  1087.     Node x = var("x");
  1088.     Node y = var("y");
  1089.     Node z = var("z");
  1090.     Node w = var("w");
  1091.     // Node a = x-y;
  1092.     // Node b = z-w;
  1093.     // Node n = a-b;
  1094.     Node n = ConstZero;
  1095.     n += x;
  1096.     n -= y;
  1097.     n -= (z-w);
  1098.    
  1099.    
  1100.     // Node n = ConstZero;
  1101.     // n += pow(x, number(4));
  1102.     // n += pow(x, number(3));
  1103.     // n += pow(x, number(2));
  1104.     // n += x;
  1105.     // n += number(10);
  1106.     // n += sin(pow(x, number(5)));
  1107.     // terms.push_back(pow(x, number(4)));
  1108.     // terms.push_back(pow(x, number(3)));
  1109.     // terms.push_back(pow(x, number(2)));
  1110.     // terms.push_back(x);
  1111.     // terms.push_back(number(10));
  1112.     // terms.push_back(sin(pow(x, number(5))));
  1113.     // Node n = sum(terms);
  1114.        
  1115.     std::cout << (*n) << std::endl;
  1116.     Node d = n->derivative(var("x"));
  1117.     std::cout << (*d) << std::endl;
  1118.     return 0;
  1119. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement