Advertisement
mwchase

LC:NN upload 3 - lib.rs

Mar 19th, 2017
127
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Rust 19.14 KB | None | 0 0
  1. extern crate num;
  2.  
  3. use std::cmp::Ordering;
  4. use std::collections::HashMap;
  5. use std::hash::{Hash, Hasher};
  6.  
  7. use num::{One, Zero};
  8.  
  9. mod node_defs;
  10. mod node_ordering;
  11.  
  12. use node_defs::{Container, container, NodeData, Sum, Product, Power, Primitive};
  13.  
  14. impl<T: NodeData> Sum<T> {
  15.     fn negate(&self, env: &mut Environment<T>) -> Container<Sum<T>> {
  16.         env.make_sum(!self.minus, &self.constant, &self.terms)
  17.     }
  18. }
  19.  
  20. impl<T: NodeData> Product<T> {
  21.     fn negate(&self, env: &mut Environment<T>) -> Container<Product<T>> {
  22.         env.make_product(&-self.coefficient.clone(), &self.powers).unwrap()
  23.     }
  24. }
  25.  
  26. type SelfMap<T> = HashMap<T, Container<T>>;
  27.  
  28. struct Environment<T: NodeData> {
  29.     sum: SelfMap<Sum<T>>,
  30.     product: SelfMap<Product<T>>,
  31.     power: SelfMap<Power<T>>,
  32.     primitive: SelfMap<Primitive<T>>,
  33.     sum_derivative: HashMap<(Sum<T>, T::Parameter), Container<Sum<T>>>,
  34.     product_derivative: HashMap<(Product<T>, T::Parameter), Container<Sum<T>>>,
  35.     power_derivative: HashMap<(Power<T>, T::Parameter), Container<Sum<T>>>,
  36.     primitive_derivative: HashMap<(Primitive<T>, T::Parameter), Container<Sum<T>>>,
  37. }
  38.  
  39. impl<T: NodeData> Sum<T> {
  40.     fn at_zero(&self) -> T::Constant {
  41.         if self.minus {
  42.             -self.constant.clone()
  43.         } else {
  44.             self.constant.clone()
  45.         }
  46.     }
  47.  
  48.     fn adjust_constant(&self,
  49.                        constant: &T::Constant,
  50.                        env: &mut Environment<T>)
  51.                        -> Container<Sum<T>> {
  52.         env.make_sum(self.minus,
  53.                      &(if self.minus {
  54.                            -constant.clone()
  55.                        } else {
  56.                            constant.clone()
  57.                        }),
  58.                      &self.terms)
  59.     }
  60.  
  61.     fn conditional_negate(&self, env: &mut Environment<T>) -> Vec<Container<Product<T>>> {
  62.         if self.minus {
  63.             env.negate_products(&self.terms)
  64.         } else {
  65.             self.terms
  66.                 .iter()
  67.                 .map(|p| p.insert(env))
  68.                 .collect::<Vec<_>>()
  69.         }
  70.     }
  71. }
  72.  
  73. impl<T: NodeData> Environment<T> {
  74.     fn make_sum(&mut self,
  75.                 minus: bool,
  76.                 constant: &T::Constant,
  77.                 terms: &Vec<Container<Product<T>>>)
  78.                 -> Container<Sum<T>> {
  79.         if minus && terms.len() == 0 {
  80.             self.make_sum(false, &-constant.clone(), terms)
  81.         } else {
  82.             Sum::new(minus, constant.clone(), terms.clone()).insert(self)
  83.         }
  84.     }
  85.  
  86.     fn make_product(&mut self,
  87.                     coefficient: &T::Constant,
  88.                     powers: &Vec<Container<Power<T>>>)
  89.                     -> Option<Container<Product<T>>> {
  90.         if coefficient.clone() == T::Constant::zero() || powers.len() == 0 {
  91.             None
  92.         } else {
  93.             Some(Product::new(coefficient.clone(), powers.clone()).insert(self))
  94.         }
  95.     }
  96.  
  97.     fn make_power(&mut self,
  98.                   primitive: &Container<Primitive<T>>,
  99.                   exponent: &T::Exponent)
  100.                   -> Option<Container<Power<T>>> {
  101.         if exponent.clone() <= T::Exponent::zero() {
  102.             None
  103.         } else {
  104.             Some(Power::new(exponent.clone(), primitive.clone()).insert(self))
  105.         }
  106.     }
  107.  
  108.     fn make_system_variable(&mut self, index: &T::SystemVariable) -> Container<Primitive<T>> {
  109.         Primitive::SystemVariable(index.clone()).insert(self)
  110.     }
  111.  
  112.     fn make_parameter(&mut self, name: &T::Parameter) -> Container<Primitive<T>> {
  113.         Primitive::Parameter(name.clone()).insert(self)
  114.     }
  115.  
  116.     fn make_input(&mut self, index: &T::Input) -> Container<Primitive<T>> {
  117.         Primitive::Input(index.clone()).insert(self)
  118.     }
  119.  
  120.     fn make_sigmoid(&mut self, minus: bool, sum: &Container<Sum<T>>) -> Container<Primitive<T>> {
  121.         if sum.minus || (sum.terms.len() == 0 && sum.constant < T::Constant::zero()) {
  122.             let sum = sum.negate(self);
  123.             self.make_sigmoid(!minus, &sum)
  124.         } else {
  125.             Primitive::Sigmoid(minus, sum.clone()).insert(self)
  126.         }
  127.     }
  128.  
  129.     fn negate_products(&mut self,
  130.                        terms: &Vec<Container<Product<T>>>)
  131.                        -> Vec<Container<Product<T>>> {
  132.         terms.iter().map(|p| p.negate(self)).collect::<Vec<_>>()
  133.     }
  134.     fn number(&mut self, number: &T::Constant) -> Container<Sum<T>> {
  135.         self.make_sum(false, number, &vec![])
  136.     }
  137.  
  138.     fn add<A: Node<T>, B: Node<T>>(&mut self,
  139.                                    left: &Container<A>,
  140.                                    right: &Container<B>)
  141.                                    -> Container<Sum<T>> {
  142.         let left_sum = left.as_sum(self);
  143.         let right_sum = right.as_sum(self);
  144.         let constant = left_sum.at_zero() + right_sum.at_zero();
  145.         if left_sum.terms.len() == 0 {
  146.             return right_sum.adjust_constant(&constant, self);
  147.         }
  148.         if right_sum.terms.len() == 0 {
  149.             return left_sum.adjust_constant(&constant, self);
  150.         }
  151.  
  152.         // Begin port of weird logic
  153.         let left_terms = left_sum.conditional_negate(self);
  154.         let right_terms = right_sum.conditional_negate(self);
  155.         let left_length = left_terms.len();
  156.         let right_length = right_terms.len();
  157.  
  158.         let mut terms = Vec::with_capacity(left_length + right_length);
  159.         let mut left_index = 0;
  160.         let mut right_index = 0;
  161.  
  162.         while left_index < left_length && right_index < right_length {
  163.             let ref left_term = left_terms[left_index];
  164.             let ref right_term = right_terms[right_length];
  165.             match left_term.fuzzy_cmp(right_term.as_ref()) {
  166.                 Ordering::Less => {
  167.                     terms.push(left_term.clone());
  168.                     left_index += 1;
  169.                 }
  170.                 Ordering::Equal => {
  171.                     let coefficient = left_term.coefficient.clone() +
  172.                                       right_term.coefficient.clone();
  173.                     if coefficient != T::Constant::zero() {
  174.                         terms.push(self.make_product(&coefficient, &left_term.powers).unwrap());
  175.                     }
  176.                     left_index += 1;
  177.                     right_index += 1;
  178.                 }
  179.                 Ordering::Greater => {
  180.                     terms.push(right_term.clone());
  181.                     right_index += 1;
  182.                 }
  183.             }
  184.             if left_index == left_length {
  185.                 terms.extend_from_slice(&right_terms[..].split_at(right_index).1);
  186.             }
  187.             if right_index == right_length {
  188.                 terms.extend_from_slice(&left_terms[..].split_at(left_index).1);
  189.             }
  190.         }
  191.  
  192.         terms.shrink_to_fit(); // Why not.
  193.  
  194.         if terms.len() == 0 || terms[0].coefficient >= T::Constant::zero() {
  195.             self.make_sum(false, &constant, &terms)
  196.         } else {
  197.             let terms = self.negate_products(&terms);
  198.             self.make_sum(true, &constant, &terms)
  199.         }
  200.     }
  201.  
  202.     fn multiply<A: Node<T>, B: Node<T>>(&mut self,
  203.                                         left: &Container<A>,
  204.                                         right: &Container<B>)
  205.                                         -> Container<Sum<T>> {
  206.         let left_sum = left.as_sum(self);
  207.         let right_sum = right.as_sum(self);
  208.         let minus = left_sum.minus != right_sum.minus;
  209.  
  210.         let first_constant = left_sum.constant.clone() * right_sum.constant.clone();
  211.         let first = self.make_sum(minus, &first_constant, &vec![]);
  212.  
  213.         let mut outer_terms = Vec::with_capacity(right_sum.terms.len());
  214.         if left_sum.constant != T::Constant::zero() {
  215.             for term in &right_sum.terms {
  216.                 let outer_coefficient = left_sum.constant.clone() * term.coefficient.clone();
  217.                 let outer_term = self.make_product(&outer_coefficient, &term.powers).unwrap();
  218.                 if left_sum.constant < T::Constant::zero() {
  219.                     outer_terms.push(outer_term.negate(self));
  220.                 } else {
  221.                     outer_terms.push(outer_term);
  222.                 }
  223.             }
  224.         }
  225.         outer_terms.shrink_to_fit();
  226.         let outer = self.make_sum((left_sum.constant < T::Constant::zero()) != minus,
  227.                                   &T::Constant::zero(),
  228.                                   &outer_terms);
  229.  
  230.         // Should be possible to turn these into functions.
  231.         let mut inner_terms = Vec::with_capacity(left_sum.terms.len());
  232.         if right_sum.constant != T::Constant::zero() {
  233.             for term in &left_sum.terms {
  234.                 let inner_coefficient = right_sum.constant.clone() * term.coefficient.clone();
  235.                 let inner_term = self.make_product(&inner_coefficient, &term.powers).unwrap();
  236.                 if right_sum.constant < T::Constant::zero() {
  237.                     inner_terms.push(inner_term.negate(self));
  238.                 } else {
  239.                     inner_terms.push(inner_term);
  240.                 }
  241.             }
  242.         }
  243.         inner_terms.shrink_to_fit();
  244.         let inner = self.make_sum((right_sum.constant < T::Constant::zero()) != minus,
  245.                                   &T::Constant::zero(),
  246.                                   &inner_terms);
  247.  
  248.         let mut last = self.number(&T::Constant::zero());
  249.         for left_term in &left_sum.terms {
  250.             for right_term in &right_sum.terms {
  251.                 let combined_terms = self._combine_terms(&left_term.powers, &right_term.powers);
  252.                 let last_coefficient = left_term.coefficient.clone() *
  253.                                        right_term.coefficient.clone();
  254.                 let product = self.make_product(&last_coefficient, &combined_terms).unwrap();
  255.                 last = self.add(&last, &product);
  256.             }
  257.         }
  258.  
  259.         last = self.add(&last, &inner);
  260.         last = self.add(&last, &outer);
  261.         self.add(&last, &first)
  262.     }
  263.  
  264.     fn _combine_terms(&mut self,
  265.                       left: &Vec<Container<Power<T>>>,
  266.                       right: &Vec<Container<Power<T>>>)
  267.                       -> Vec<Container<Power<T>>> {
  268.         let left_length = left.len();
  269.         let right_length = right.len();
  270.  
  271.         let mut powers = Vec::with_capacity(left_length + right_length);
  272.         let mut left_index = 0;
  273.         let mut right_index = 0;
  274.  
  275.         while left_index < left_length && right_index < right_length {
  276.             let ref left_power = left[left_index];
  277.             let ref right_power = right[right_length];
  278.             match left_power.fuzzy_cmp(&right_power) {
  279.                 Ordering::Less => {
  280.                     powers.push(left_power.clone());
  281.                     left_index += 1;
  282.                 }
  283.                 Ordering::Equal => {
  284.                     let exponent = left_power.exponent.clone() + right_power.exponent.clone();
  285.                     powers.push(self.make_power(&left_power.primitive, &exponent).unwrap());
  286.                     left_index += 1;
  287.                     right_index += 1;
  288.                 }
  289.                 Ordering::Greater => {
  290.                     powers.push(right_power.clone());
  291.                     right_index += 1;
  292.                 }
  293.             }
  294.             if left_index == left_length {
  295.                 powers.extend_from_slice(&right[..].split_at(right_index).1);
  296.             }
  297.             if right_index == right_length {
  298.                 powers.extend_from_slice(&left[..].split_at(left_index).1);
  299.             }
  300.         }
  301.  
  302.         powers.shrink_to_fit(); // Why not.
  303.  
  304.         powers
  305.  
  306.     }
  307. }
  308.  
  309. trait Summable<T: NodeData> {
  310.     fn as_sum(&self, env: &mut Environment<T>) -> Container<Sum<T>>;
  311. }
  312.  
  313. trait Productable<T: NodeData> {
  314.     fn as_product(&self, env: &mut Environment<T>) -> Container<Product<T>>;
  315. }
  316.  
  317. fn _as_sum<T: NodeData>(productable: &Productable<T>,
  318.                         env: &mut Environment<T>)
  319.                         -> Container<Sum<T>> {
  320.     let product = productable.as_product(env);
  321.     if product.coefficient < T::Constant::zero() {
  322.         let product = product.negate(env);
  323.         env.make_sum(true, &T::Constant::zero(), &vec![product])
  324.     } else {
  325.         env.make_sum(false, &T::Constant::zero(), &vec![product])
  326.     }
  327. }
  328.  
  329. impl<T: NodeData> Summable<T> for Product<T> {
  330.     fn as_sum(&self, env: &mut Environment<T>) -> Container<Sum<T>> {
  331.         _as_sum(self, env)
  332.     }
  333. }
  334.  
  335. impl<T: NodeData> Summable<T> for Power<T> {
  336.     fn as_sum(&self, env: &mut Environment<T>) -> Container<Sum<T>> {
  337.         _as_sum(self, env)
  338.     }
  339. }
  340.  
  341. impl<T: NodeData> Summable<T> for Primitive<T> {
  342.     fn as_sum(&self, env: &mut Environment<T>) -> Container<Sum<T>> {
  343.         _as_sum(self, env)
  344.     }
  345. }
  346.  
  347. trait Node<T: NodeData>: Summable<T> {
  348.     fn partial_derivative(&self,
  349.                           variable: &T::Parameter,
  350.                           env: &mut Environment<T>)
  351.                           -> Container<Sum<T>>;
  352.  
  353.     fn fancy_clone(&self, env: &mut Environment<T>) -> Self;
  354.  
  355.     fn get_storage<'a>(&self, env: &'a mut Environment<T>) -> &'a mut SelfMap<Self>
  356.        where Self: Sized;
  357.  
  358.    fn insert(&self, env: &mut Environment<T>) -> Container<Self>
  359.        where Self: Sized + Clone + Eq + Hash
  360.    {
  361.        if !self.get_storage(env).contains_key(self) {
  362.            let key = self.fancy_clone(env);
  363.            let value = key.clone();
  364.            self.get_storage(env).insert(key, container(value));
  365.        }
  366.        match self.get_storage(env).get(self) {
  367.                Some(node) => node,
  368.                None => unreachable!(),
  369.            }
  370.            .clone()
  371.    }
  372. }
  373.  
  374. impl<T: NodeData> Summable<T> for Sum<T> {
  375.    fn as_sum(&self, env: &mut Environment<T>) -> Container<Sum<T>> {
  376.        self.insert(env)
  377.    }
  378. }
  379.  
  380. impl<T: NodeData> Node<T> for Sum<T> {
  381.    fn partial_derivative(&self,
  382.                          variable: &T::Parameter,
  383.                          env: &mut Environment<T>)
  384.                          -> Container<Sum<T>> {
  385.        if !env.sum_derivative.contains_key(&(self.clone(), variable.clone())) {
  386.            let key = (self.fancy_clone(env), variable.clone());
  387.            let mut result = env.number(&T::Constant::zero());
  388.            for term in &self.terms {
  389.                let partial = term.partial_derivative(variable, env);
  390.                result = env.add(&result, &partial);
  391.            }
  392.            if self.minus {
  393.                result = result.negate(env);
  394.            }
  395.            env.sum_derivative.insert(key, result);
  396.        }
  397.        match env.sum_derivative.get(&(self.clone(), variable.clone())) {
  398.            Some(derivative) => derivative.clone(),
  399.            None => unreachable!(),
  400.        }
  401.    }
  402.  
  403.    fn fancy_clone(&self, env: &mut Environment<T>) -> Sum<T> {
  404.        Sum {
  405.            pre_hash: self.pre_hash.clone(),
  406.            minus: self.minus.clone(),
  407.            constant: self.constant.clone(),
  408.            terms: self.terms
  409.                .iter()
  410.                .map(|p| p.insert(env))
  411.                .collect::<Vec<_>>(),
  412.        }
  413.    }
  414.  
  415.    fn get_storage<'a>(&self, env: &'a mut Environment<T>) -> &'a mut SelfMap<Sum<T>> {
  416.         &mut env.sum
  417.     }
  418. }
  419.  
  420. impl<T: NodeData> Productable<T> for Product<T> {
  421.     fn as_product(&self, env: &mut Environment<T>) -> Container<Product<T>> {
  422.         self.insert(env)
  423.     }
  424. }
  425.  
  426. impl<T: NodeData> Node<T> for Product<T> {
  427.     fn partial_derivative(&self,
  428.                           variable: &T::Parameter,
  429.                           env: &mut Environment<T>)
  430.                           -> Container<Sum<T>> {
  431.         if !env.product_derivative.contains_key(&(self.clone(), variable.clone())) {
  432.             let key = (self.fancy_clone(env), variable.clone());
  433.             let result = if self.powers.len() == 1 {
  434.                 self.powers[0].partial_derivative(variable, env)
  435.             } else {
  436.                 let mut result = env.number(&T::Constant::zero());
  437.                 unimplemented!()
  438.                 /*for (index, focus) in self.powers.iter().enumerate() {
  439.                     let remainder = env.make_product(T::Constant::one(), )
  440.                 }*/
  441.             };
  442.             env.product_derivative.insert(key, result);
  443.         }
  444.         match env.product_derivative.get(&(self.clone(), variable.clone())) {
  445.             Some(derivative) => derivative.clone(),
  446.             None => unreachable!(),
  447.         }
  448.     }
  449.  
  450.     fn fancy_clone(&self, env: &mut Environment<T>) -> Product<T> {
  451.         Product {
  452.             pre_hash: self.pre_hash.clone(),
  453.             coefficient: self.coefficient.clone(),
  454.             powers: self.powers
  455.                 .iter()
  456.                 .map(|p| p.insert(env))
  457.                 .collect::<Vec<_>>(),
  458.         }
  459.     }
  460.  
  461.     fn get_storage<'a>(&self, env: &'a mut Environment<T>) -> &'a mut SelfMap<Product<T>> {
  462.        &mut env.product
  463.    }
  464. }
  465.  
  466. impl<T: NodeData> Productable<T> for Power<T> {
  467.    fn as_product(&self, env: &mut Environment<T>) -> Container<Product<T>> {
  468.        let power = self.insert(env);
  469.        env.make_product(&T::Constant::one(), &vec![power]).unwrap()
  470.    }
  471. }
  472.  
  473. impl<T: NodeData> Node<T> for Power<T> {
  474.    fn partial_derivative(&self,
  475.                          variable: &T::Parameter,
  476.                          env: &mut Environment<T>)
  477.                          -> Container<Sum<T>> {
  478.        unimplemented!();
  479.    }
  480.  
  481.    fn fancy_clone(&self, env: &mut Environment<T>) -> Power<T> {
  482.        Power {
  483.            exponent: self.exponent.clone(),
  484.            primitive: self.primitive.insert(env),
  485.        }
  486.    }
  487.  
  488.    fn get_storage<'a>(&self, env: &'a mut Environment<T>) -> &'a mut SelfMap<Power<T>> {
  489.         &mut env.power
  490.     }
  491. }
  492.  
  493. impl<T: NodeData> Primitive<T> {
  494.     fn as_power(&self, env: &mut Environment<T>) -> Container<Power<T>> {
  495.         let primitive = self.insert(env);
  496.         env.make_power(&primitive, &T::Exponent::one()).unwrap()
  497.     }
  498. }
  499.  
  500. impl<T: NodeData> Productable<T> for Primitive<T> {
  501.     fn as_product(&self, env: &mut Environment<T>) -> Container<Product<T>> {
  502.         match self {
  503.             &Primitive::Sigmoid(true, ref sum) => {
  504.                 let sigmoid = env.make_sigmoid(false, sum);
  505.                 let power = sigmoid.as_power(env);
  506.                 env.make_product(&-T::Constant::one(), &vec![power]).unwrap()
  507.             }
  508.             _ => {
  509.                 let power = self.as_power(env);
  510.                 power.as_product(env)
  511.             }
  512.         }
  513.     }
  514. }
  515.  
  516. impl<T: NodeData> Node<T> for Primitive<T> {
  517.     fn partial_derivative(&self,
  518.                           variable: &T::Parameter,
  519.                           env: &mut Environment<T>)
  520.                           -> Container<Sum<T>> {
  521.         unimplemented!();
  522.     }
  523.  
  524.     fn fancy_clone(&self, env: &mut Environment<T>) -> Primitive<T> {
  525.         match self {
  526.             &Primitive::Sigmoid(minus, ref a) => Primitive::Sigmoid(minus, a.insert(env)),
  527.             _ => self.clone(),
  528.         }
  529.     }
  530.  
  531.     fn get_storage<'a>(&self, env: &'a mut Environment<T>) -> &'a mut SelfMap<Primitive<T>> {
  532.        &mut env.primitive
  533.    }
  534. }
  535.  
  536. #[cfg(test)]
  537. mod tests {
  538.    #[test]
  539.    fn it_works() {}
  540. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement