Advertisement
mwchase

LC:NN upload 2a

Mar 12th, 2017
140
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Rust 16.64 KB | None | 0 0
  1. extern crate num;
  2.  
  3. use std::cmp::Ordering;
  4. use std::collections::HashMap;
  5. use std::collections::hash_map::{DefaultHasher, Entry};
  6. use std::hash::{Hash, Hasher};
  7. use std::sync::Arc;
  8.  
  9. use num::{Integer, One, Signed, Zero};
  10.  
  11. trait NodeData {
  12.     type Constant: Hash + Eq + Ord + Clone + Signed;
  13.     type Coefficient: Hash + Eq + Ord + Clone + Signed;
  14.     type Exponent: Hash + Clone + Integer;
  15.     type Input: Hash + Eq + Clone;
  16.     type SystemVariable: Hash + Eq + Clone;
  17.     type Parameter: Hash + Eq + Clone;
  18. }
  19.  
  20. struct Sum<T: NodeData> {
  21.     pre_hash: u64,
  22.     minus: bool,
  23.     constant: T::Constant,
  24.     terms: Vec<Arc<Product<T>>>,
  25. }
  26.  
  27. impl<T: NodeData> Sum<T> {
  28.     fn negate(&self, env: &mut Environment<T>) -> Arc<Sum<T>> {
  29.         env.make_sum(!self.minus, &self.constant, &self.terms)
  30.     }
  31. }
  32.  
  33. struct Product<T: NodeData> {
  34.     pre_hash: u64,
  35.     coefficient: T::Coefficient,
  36.     powers: Vec<Arc<Power<T>>>,
  37. }
  38.  
  39. impl<T: NodeData> Product<T> {
  40.     fn negate(&self, env: &mut Environment<T>) -> Arc<Product<T>> {
  41.         env.make_product(&-self.coefficient.clone(), &self.powers).unwrap()
  42.     }
  43. }
  44.  
  45. struct Power<T: NodeData> {
  46.     exponent: T::Exponent,
  47.     primitive: Arc<Primitive<T>>,
  48. }
  49.  
  50. enum Primitive<T: NodeData> {
  51.     Input(T::Input),
  52.     SystemVariable(T::SystemVariable),
  53.     Parameter(T::Parameter),
  54.     Sigmoid(bool, Arc<Sum<T>>),
  55. }
  56.  
  57. type SelfMap<T> = HashMap<T, Arc<T>>;
  58.  
  59. struct Environment<T: NodeData> {
  60.     sum: SelfMap<Sum<T>>,
  61.     product: SelfMap<Product<T>>,
  62.     power: SelfMap<Power<T>>,
  63.     primitive: SelfMap<Primitive<T>>, //I could put in the maps for the various partial
  64.                                       //derivatives, but that's a
  65.                                       //non-vital optimization.
  66. }
  67.  
  68. impl<T: NodeData> Environment<T> {
  69.     fn make_sum(&mut self,
  70.                 minus: bool,
  71.                 constant: &T::Constant,
  72.                 terms: &Vec<Arc<Product<T>>>)
  73.                 -> Arc<Sum<T>> {
  74.         if minus && terms.len() == 0 {
  75.             self.make_sum(false, &-constant.clone(), terms)
  76.         } else {
  77.             Sum::new(minus, constant.clone(), terms.clone()).insert(self)
  78.         }
  79.     }
  80.  
  81.     fn make_product(&mut self,
  82.                     coefficient: &T::Coefficient,
  83.                     powers: &Vec<Arc<Power<T>>>)
  84.                     -> Option<Arc<Product<T>>> {
  85.         if coefficient.clone() == T::Coefficient::zero() || powers.len() == 0 {
  86.             None
  87.         } else {
  88.             Some(Product::new(coefficient.clone(), powers.clone()).insert(self))
  89.         }
  90.     }
  91.  
  92.     fn make_power(&mut self,
  93.                   primitive: &Arc<Primitive<T>>,
  94.                   exponent: &T::Exponent)
  95.                   -> Option<Arc<Power<T>>> {
  96.         if exponent.clone() <= T::Exponent::zero() {
  97.             None
  98.         } else {
  99.             Some(Power::new(exponent.clone(), primitive.clone()).insert(self))
  100.         }
  101.     }
  102.  
  103.     fn make_system_variable(&mut self, index: &T::SystemVariable) -> Arc<Primitive<T>> {
  104.         Primitive::SystemVariable(index.clone()).insert(self)
  105.     }
  106.  
  107.     fn make_parameter(&mut self, name: &T::Parameter) -> Arc<Primitive<T>> {
  108.         Primitive::Parameter(name.clone()).insert(self)
  109.     }
  110.  
  111.     fn make_input(&mut self, index: &T::Input) -> Arc<Primitive<T>> {
  112.         Primitive::Input(index.clone()).insert(self)
  113.     }
  114.  
  115.     fn make_sigmoid(&mut self, minus: bool, sum: &Arc<Sum<T>>) -> Arc<Primitive<T>> {
  116.         if sum.minus || (sum.terms.len() == 0 && sum.constant < T::Constant::zero()) {
  117.             let sum = sum.negate(self);
  118.             self.make_sigmoid(!minus, &sum)
  119.         } else {
  120.             Primitive::Sigmoid(minus, sum.clone()).insert(self)
  121.         }
  122.     }
  123.     //fn number(&mut self, &T::Constant)
  124. }
  125.  
  126. trait Summable<T: NodeData> {
  127.     fn as_sum(&self, env: &mut Environment<T>) -> Arc<Sum<T>>;
  128. }
  129.  
  130. trait Productable<T: NodeData> {
  131.     fn as_product(&self, env: &mut Environment<T>) -> Arc<Product<T>>;
  132. }
  133.  
  134. fn _as_sum<T: NodeData>(productable: &Productable<T>, env: &mut Environment<T>) -> Arc<Sum<T>> {
  135.     let product = productable.as_product(env);
  136.     if product.coefficient < T::Coefficient::zero() {
  137.         let product = product.negate(env);
  138.         env.make_sum(true, &T::Constant::zero(), &vec![product])
  139.     } else {
  140.         env.make_sum(false, &T::Constant::zero(), &vec![product])
  141.     }
  142. }
  143.  
  144. impl<T: NodeData> Summable<T> for Product<T> {
  145.     fn as_sum(&self, env: &mut Environment<T>) -> Arc<Sum<T>> {
  146.         _as_sum(self, env)
  147.     }
  148. }
  149.  
  150. impl<T: NodeData> Summable<T> for Power<T> {
  151.     fn as_sum(&self, env: &mut Environment<T>) -> Arc<Sum<T>> {
  152.         _as_sum(self, env)
  153.     }
  154. }
  155.  
  156. impl<T: NodeData> Summable<T> for Primitive<T> {
  157.     fn as_sum(&self, env: &mut Environment<T>) -> Arc<Sum<T>> {
  158.         _as_sum(self, env)
  159.     }
  160. }
  161.  
  162. trait Node<T: NodeData>: Summable<T> {
  163.     fn partial_derivative(&self, variable: &T::Parameter, env: &mut Environment<T>) -> Arc<Sum<T>>;
  164.  
  165.     fn fancy_clone(&self, env: &mut Environment<T>) -> Self;
  166.  
  167.     fn get_storage<'a>(&self, env: &'a mut Environment<T>) -> &'a mut SelfMap<Self>
  168.        where Self: Sized;
  169.  
  170.    fn insert(&self, env: &mut Environment<T>) -> Arc<Self>
  171.        where Self: Sized + Clone + Eq + Hash
  172.    {
  173.        if !self.get_storage(env).contains_key(self) {
  174.            let key = self.fancy_clone(env);
  175.            let value = key.clone();
  176.            self.get_storage(env).insert(key, Arc::new(value));
  177.        }
  178.        match self.get_storage(env).get(self) {
  179.                Some(node) => node,
  180.                None => unreachable!(),
  181.            }
  182.            .clone()
  183.    }
  184. }
  185.  
  186. impl<T: NodeData> Summable<T> for Sum<T> {
  187.    fn as_sum(&self, env: &mut Environment<T>) -> Arc<Sum<T>> {
  188.        self.insert(env)
  189.    }
  190. }
  191.  
  192. impl<T: NodeData> Node<T> for Sum<T> {
  193.    fn partial_derivative(&self, variable: &T::Parameter, env: &mut Environment<T>) -> Arc<Sum<T>> {
  194.        unimplemented!();
  195.    }
  196.  
  197.    fn fancy_clone(&self, env: &mut Environment<T>) -> Sum<T> {
  198.        Sum {
  199.            pre_hash: self.pre_hash.clone(),
  200.            minus: self.minus.clone(),
  201.            constant: self.constant.clone(),
  202.            terms: self.terms
  203.                .iter()
  204.                .map(|p| p.insert(env))
  205.                .collect::<Vec<_>>(),
  206.        }
  207.    }
  208.  
  209.    fn get_storage<'a>(&self, env: &'a mut Environment<T>) -> &'a mut SelfMap<Sum<T>> {
  210.         &mut env.sum
  211.     }
  212. }
  213.  
  214. impl<T: NodeData> Productable<T> for Product<T> {
  215.     fn as_product(&self, env: &mut Environment<T>) -> Arc<Product<T>> {
  216.         self.insert(env)
  217.     }
  218. }
  219.  
  220. impl<T: NodeData> Node<T> for Product<T> {
  221.     fn partial_derivative(&self, variable: &T::Parameter, env: &mut Environment<T>) -> Arc<Sum<T>> {
  222.         unimplemented!();
  223.     }
  224.  
  225.     fn fancy_clone(&self, env: &mut Environment<T>) -> Product<T> {
  226.         Product {
  227.             pre_hash: self.pre_hash.clone(),
  228.             coefficient: self.coefficient.clone(),
  229.             powers: self.powers
  230.                 .iter()
  231.                 .map(|p| p.insert(env))
  232.                 .collect::<Vec<_>>(),
  233.         }
  234.     }
  235.  
  236.     fn get_storage<'a>(&self, env: &'a mut Environment<T>) -> &'a mut SelfMap<Product<T>> {
  237.        &mut env.product
  238.    }
  239. }
  240.  
  241. impl<T: NodeData> Productable<T> for Power<T> {
  242.    fn as_product(&self, env: &mut Environment<T>) -> Arc<Product<T>> {
  243.        let power = self.insert(env);
  244.        env.make_product(&T::Coefficient::one(), &vec![power]).unwrap()
  245.    }
  246. }
  247.  
  248. impl<T: NodeData> Node<T> for Power<T> {
  249.    fn partial_derivative(&self, variable: &T::Parameter, env: &mut Environment<T>) -> Arc<Sum<T>> {
  250.        unimplemented!();
  251.    }
  252.  
  253.    fn fancy_clone(&self, env: &mut Environment<T>) -> Power<T> {
  254.        Power {
  255.            exponent: self.exponent.clone(),
  256.            primitive: self.primitive.insert(env),
  257.        }
  258.    }
  259.  
  260.    fn get_storage<'a>(&self, env: &'a mut Environment<T>) -> &'a mut SelfMap<Power<T>> {
  261.         &mut env.power
  262.     }
  263. }
  264.  
  265. impl<T: NodeData> Primitive<T> {
  266.     fn as_power(&self, env: &mut Environment<T>) -> Arc<Power<T>> {
  267.         let primitive = self.insert(env);
  268.         env.make_power(&primitive, &T::Exponent::one()).unwrap()
  269.     }
  270. }
  271.  
  272. impl<T: NodeData> Productable<T> for Primitive<T> {
  273.     fn as_product(&self, env: &mut Environment<T>) -> Arc<Product<T>> {
  274.         match self {
  275.             &Primitive::Sigmoid(true, ref sum) => {
  276.                 let sigmoid = env.make_sigmoid(false, sum);
  277.                 let power = sigmoid.as_power(env);
  278.                 env.make_product(&-T::Coefficient::one(), &vec![power]).unwrap()
  279.             }
  280.             _ => {
  281.                 let power = self.as_power(env);
  282.                 power.as_product(env)
  283.             }
  284.         }
  285.     }
  286. }
  287.  
  288. impl<T: NodeData> Node<T> for Primitive<T> {
  289.     fn partial_derivative(&self, variable: &T::Parameter, env: &mut Environment<T>) -> Arc<Sum<T>> {
  290.         unimplemented!();
  291.     }
  292.  
  293.     fn fancy_clone(&self, env: &mut Environment<T>) -> Primitive<T> {
  294.         match self {
  295.             &Primitive::Sigmoid(minus, ref a) => Primitive::Sigmoid(minus, a.insert(env)),
  296.             _ => self.clone(),
  297.         }
  298.     }
  299.  
  300.     fn get_storage<'a>(&self, env: &'a mut Environment<T>) -> &'a mut SelfMap<Primitive<T>> {
  301.        &mut env.primitive
  302.    }
  303. }
  304.  
  305. impl<T: NodeData> Sum<T> {
  306.    fn new(minus: bool, constant: T::Constant, terms: Vec<Arc<Product<T>>>) -> Sum<T> {
  307.        let mut s = DefaultHasher::new();
  308.        minus.hash(&mut s);
  309.        constant.hash(&mut s);
  310.        terms.hash(&mut s);
  311.        Sum {
  312.            pre_hash: s.finish(),
  313.            minus: minus,
  314.            constant: constant,
  315.            terms: terms,
  316.        }
  317.    }
  318. }
  319.  
  320. impl<T: NodeData> Product<T> {
  321.    fn new(coefficient: T::Coefficient, powers: Vec<Arc<Power<T>>>) -> Product<T> {
  322.        let mut s = DefaultHasher::new();
  323.        coefficient.hash(&mut s);
  324.        powers.hash(&mut s);
  325.        Product {
  326.            pre_hash: s.finish(),
  327.            coefficient: coefficient,
  328.            powers: powers,
  329.        }
  330.    }
  331. }
  332.  
  333. impl<T: NodeData> Power<T> {
  334.    fn new(exponent: T::Exponent, primitive: Arc<Primitive<T>>) -> Power<T> {
  335.        Power {
  336.            exponent: exponent,
  337.            primitive: primitive,
  338.        }
  339.    }
  340. }
  341.  
  342. impl<T: NodeData> Ord for Sum<T> {
  343.    fn cmp(&self, other: &Sum<T>) -> Ordering {
  344.        let cmp_result = self.constant.cmp(&other.constant);
  345.        match cmp_result {
  346.            Ordering::Equal => {
  347.                for (self_term, other_term) in
  348.                    self.terms
  349.                        .clone()
  350.                        .into_iter()
  351.                        .zip(other.terms.clone()) {
  352.                    let cmp_result = self_term.cmp(&other_term);
  353.                    match cmp_result {
  354.                        Ordering::Equal => (),
  355.                        _ => return cmp_result,
  356.                    }
  357.                }
  358.                let cmp_result = self.terms.len().cmp(&other.terms.len());
  359.                match cmp_result {
  360.                    Ordering::Equal => self.minus.cmp(&other.minus),
  361.                    _ => cmp_result,
  362.                }
  363.            }
  364.            _ => cmp_result,
  365.        }
  366.    }
  367. }
  368.  
  369. impl<T: NodeData> Ord for Product<T> {
  370.    fn cmp(&self, other: &Product<T>) -> Ordering {
  371.        let cmp_result = self.fuzzy_cmp(other);
  372.        match cmp_result {
  373.            Ordering::Equal => self.coefficient.cmp(&other.coefficient),
  374.            _ => cmp_result,
  375.        }
  376.    }
  377. }
  378.  
  379. impl<T: NodeData> Product<T> {
  380.    fn fuzzy_cmp(&self, other: &Product<T>) -> Ordering {
  381.        for (self_power, other_power) in
  382.            self.powers
  383.                .clone()
  384.                .into_iter()
  385.                .zip(other.powers.clone()) {
  386.            let cmp_result = self_power.cmp(&other_power);
  387.            match cmp_result {
  388.                Ordering::Equal => (),
  389.                _ => return cmp_result,
  390.            }
  391.        }
  392.        self.powers.len().cmp(&other.powers.len())
  393.    }
  394. }
  395.  
  396. impl<T: NodeData> Ord for Power<T> {
  397.    fn cmp(&self, other: &Power<T>) -> Ordering {
  398.        let cmp_result = self.primitive.cmp(&other.primitive);
  399.        match cmp_result {
  400.            Ordering::Equal => self.exponent.cmp(&other.exponent),
  401.            _ => cmp_result,
  402.        }
  403.    }
  404. }
  405.  
  406. impl<T: NodeData> Ord for Primitive<T> {
  407.    fn cmp(&self, other: &Primitive<T>) -> Ordering {
  408.        unimplemented!()
  409.    }
  410. }
  411.  
  412. impl<T: NodeData> PartialOrd for Sum<T> {
  413.    fn partial_cmp(&self, other: &Sum<T>) -> Option<Ordering> {
  414.        Some(self.cmp(other))
  415.    }
  416. }
  417. impl<T: NodeData> PartialOrd for Product<T> {
  418.    fn partial_cmp(&self, other: &Product<T>) -> Option<Ordering> {
  419.        Some(self.cmp(other))
  420.    }
  421. }
  422. impl<T: NodeData> PartialOrd for Power<T> {
  423.    fn partial_cmp(&self, other: &Power<T>) -> Option<Ordering> {
  424.        Some(self.cmp(other))
  425.    }
  426. }
  427. impl<T: NodeData> PartialOrd for Primitive<T> {
  428.    fn partial_cmp(&self, other: &Primitive<T>) -> Option<Ordering> {
  429.        Some(self.cmp(other))
  430.    }
  431. }
  432.  
  433. impl<T: NodeData> Hash for Sum<T> {
  434.    fn hash<H: Hasher>(&self, state: &mut H) {
  435.        self.pre_hash.hash(state);
  436.    }
  437. }
  438.  
  439. impl<T: NodeData> Hash for Product<T> {
  440.    fn hash<H: Hasher>(&self, state: &mut H) {
  441.        self.pre_hash.hash(state);
  442.    }
  443. }
  444.  
  445. impl<T: NodeData> Hash for Power<T> {
  446.    fn hash<H: Hasher>(&self, state: &mut H) {
  447.        self.exponent.hash(state);
  448.        self.primitive.hash(state);
  449.    }
  450. }
  451.  
  452. impl<T: NodeData> Hash for Primitive<T> {
  453.    fn hash<H: Hasher>(&self, state: &mut H) {
  454.        match self {
  455.            &Primitive::Input(ref a) => {
  456.                "Input".hash(state);
  457.                a.hash(state)
  458.            }
  459.            &Primitive::SystemVariable(ref a) => {
  460.                "SystemVariable".hash(state);
  461.                a.hash(state)
  462.            }
  463.            &Primitive::Parameter(ref a) => {
  464.                "Parameter".hash(state);
  465.                a.hash(state)
  466.            }
  467.            &Primitive::Sigmoid(minus, ref a) => {
  468.                "Sigmoid".hash(state);
  469.                minus.hash(state);
  470.                a.hash(state)
  471.            }
  472.        }
  473.    }
  474. }
  475.  
  476. impl<T: NodeData> PartialEq for Sum<T> {
  477.    fn eq(&self, other: &Sum<T>) -> bool {
  478.        self.minus == other.minus && self.constant == other.constant && self.terms == other.terms
  479.    }
  480. }
  481.  
  482. impl<T: NodeData> PartialEq for Product<T> {
  483.    fn eq(&self, other: &Product<T>) -> bool {
  484.        self.coefficient == other.coefficient && self.powers == other.powers
  485.    }
  486. }
  487.  
  488. impl<T: NodeData> PartialEq for Power<T> {
  489.    fn eq(&self, other: &Power<T>) -> bool {
  490.        self.exponent == other.exponent && self.primitive == other.primitive
  491.    }
  492. }
  493.  
  494. impl<T: NodeData> PartialEq for Primitive<T> {
  495.    fn eq(&self, other: &Primitive<T>) -> bool {
  496.        match (self, other) {
  497.            (&Primitive::Input(ref a), &Primitive::Input(ref b)) => a == b,
  498.            (&Primitive::SystemVariable(ref a), &Primitive::SystemVariable(ref b)) => a == b,
  499.            (&Primitive::Parameter(ref a), &Primitive::Parameter(ref b)) => a == b,
  500.            (&Primitive::Sigmoid(minus_a, ref a), &Primitive::Sigmoid(minus_b, ref b)) => {
  501.                minus_a == minus_b && a == b
  502.            }
  503.            _ => false,
  504.        }
  505.    }
  506. }
  507.  
  508. impl<T: NodeData> Eq for Sum<T> {}
  509. impl<T: NodeData> Eq for Product<T> {}
  510. impl<T: NodeData> Eq for Power<T> {}
  511. impl<T: NodeData> Eq for Primitive<T> {}
  512.  
  513. impl<T: NodeData> Clone for Sum<T> {
  514.    fn clone(&self) -> Sum<T> {
  515.        Sum {
  516.            pre_hash: self.pre_hash.clone(),
  517.            minus: self.minus.clone(),
  518.            constant: self.constant.clone(),
  519.            terms: self.terms.clone(),
  520.        }
  521.    }
  522. }
  523.  
  524. impl<T: NodeData> Clone for Product<T> {
  525.    fn clone(&self) -> Product<T> {
  526.        Product {
  527.            pre_hash: self.pre_hash.clone(),
  528.            coefficient: self.coefficient.clone(),
  529.            powers: self.powers.clone(),
  530.        }
  531.    }
  532. }
  533.  
  534. impl<T: NodeData> Clone for Power<T> {
  535.    fn clone(&self) -> Power<T> {
  536.        Power {
  537.            exponent: self.exponent.clone(),
  538.            primitive: self.primitive.clone(),
  539.        }
  540.    }
  541. }
  542.  
  543. impl<T: NodeData> Clone for Primitive<T> {
  544.    fn clone(&self) -> Primitive<T> {
  545.        match self {
  546.            &Primitive::Input(ref a) => Primitive::Input(a.clone()),
  547.            &Primitive::SystemVariable(ref a) => Primitive::SystemVariable(a.clone()),
  548.            &Primitive::Parameter(ref a) => Primitive::Parameter(a.clone()),
  549.            &Primitive::Sigmoid(minus, ref a) => Primitive::Sigmoid(minus, a.clone()),
  550.        }
  551.    }
  552. }
  553.  
  554. #[cfg(test)]
  555. mod tests {
  556.    #[test]
  557.    fn it_works() {}
  558. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement