Advertisement
mwchase

LC:NN upload 4 - diff

Mar 26th, 2017
231
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Diff 37.03 KB | None | 0 0
  1. diff --git a/rust/adagio/src/eval.rs b/rust/adagio/src/eval.rs
  2. new file mode 100644
  3. diff --git a/rust/adagio/src/lib.rs b/rust/adagio/src/lib.rs
  4. --- a/rust/adagio/src/lib.rs
  5. +++ b/rust/adagio/src/lib.rs
  6. @@ -2,30 +2,37 @@
  7.  
  8.  use std::cmp::Ordering;
  9.  use std::collections::HashMap;
  10. -use std::hash::{Hash, Hasher};
  11. +use std::hash::Hash;
  12.  
  13. -use num::{One, Zero};
  14. -
  15. +mod node_conversions;
  16.  mod node_defs;
  17.  mod node_ordering;
  18.  
  19. -use node_defs::{Container, container, NodeData, Sum, Product, Power, Primitive};
  20. +use node_conversions::Summable;
  21. +use node_defs::{Container, container, FancyClone, NodeData, Sum, Product, Power, Primitive};
  22.  
  23. +/// Negates a Sum node.
  24.  impl<T: NodeData> Sum<T> {
  25.      fn negate(&self, env: &mut Environment<T>) -> Container<Sum<T>> {
  26. -        env.make_sum(!self.minus, &self.constant, &self.terms)
  27. +        env.make_sum(!self.minus,
  28. +                     self.constant.clone(),
  29. +                     self.terms.iter().cloned())
  30.      }
  31.  }
  32.  
  33. +/// Negates a Product node.
  34.  impl<T: NodeData> Product<T> {
  35.      fn negate(&self, env: &mut Environment<T>) -> Container<Product<T>> {
  36. -        env.make_product(&-self.coefficient.clone(), &self.powers).unwrap()
  37. +        env.secret_make_product(-self.coefficient.clone(), self.powers.clone())
  38.      }
  39.  }
  40.  
  41. +/// Convenience type, maps an instance of a type to an Arc containing a clone of the instance.
  42.  type SelfMap<T> = HashMap<T, Container<T>>;
  43.  
  44. -struct Environment<T: NodeData> {
  45. +/// A struct that contains canonicalization maps for every node type, as well as maps to
  46. +/// memoize partial derivatives.
  47. +pub struct Environment<T: NodeData> {
  48.      sum: SelfMap<Sum<T>>,
  49.      product: SelfMap<Product<T>>,
  50.      power: SelfMap<Power<T>>,
  51. @@ -37,6 +44,7 @@
  52.  }
  53.  
  54.  impl<T: NodeData> Sum<T> {
  55. +    /// Returns the value of the sum, if every term were 0.
  56.      fn at_zero(&self) -> T::Constant {
  57.          if self.minus {
  58.              -self.constant.clone()
  59. @@ -45,108 +53,143 @@
  60.          }
  61.      }
  62.  
  63. +    /// Returns a Sum that has the same shape as the sum, but equal to constant at zero.
  64.      fn adjust_constant(&self,
  65. -                       constant: &T::Constant,
  66. +                       constant: T::Constant,
  67.                         env: &mut Environment<T>)
  68.                         -> Container<Sum<T>> {
  69.          env.make_sum(self.minus,
  70. -                     &(if self.minus {
  71. -                           -constant.clone()
  72. -                       } else {
  73. -                           constant.clone()
  74. -                       }),
  75. -                     &self.terms)
  76. +                     (if self.minus { -constant } else { constant }),
  77. +                     self.terms.iter().cloned())
  78.      }
  79.  
  80. +    /// Returns the terms of the sum, negated if the sum was negated.
  81.      fn conditional_negate(&self, env: &mut Environment<T>) -> Vec<Container<Product<T>>> {
  82.          if self.minus {
  83. -            env.negate_products(&self.terms)
  84. +            env.negate_products(self.terms.iter().cloned())
  85.          } else {
  86.              self.terms
  87.                  .iter()
  88.                  .map(|p| p.insert(env))
  89. -                .collect::<Vec<_>>()
  90. +                .collect()
  91.          }
  92.      }
  93.  }
  94.  
  95.  impl<T: NodeData> Environment<T> {
  96. -    fn make_sum(&mut self,
  97. -                minus: bool,
  98. -                constant: &T::Constant,
  99. -                terms: &Vec<Container<Product<T>>>)
  100. -                -> Container<Sum<T>> {
  101. +    /// Makes a canonicalized Sum from the given components.
  102. +    pub fn make_sum<C, N: Into<T::Constant>>(&mut self,
  103. +                                             minus: bool,
  104. +                                             constant: N,
  105. +                                             terms: C)
  106. +                                             -> Container<Sum<T>>
  107. +        where C: IntoIterator<Item = Container<Product<T>>>
  108. +    {
  109. +        let terms = terms.into_iter().collect::<Vec<_>>();
  110. +        let mut minus = minus;
  111. +        let mut constant = constant.into();
  112.          if minus && terms.len() == 0 {
  113. -            self.make_sum(false, &-constant.clone(), terms)
  114. -        } else {
  115. -            Sum::new(minus, constant.clone(), terms.clone()).insert(self)
  116. +            minus = false;
  117. +            constant = -constant;
  118.          }
  119. +        Sum::new(minus, constant, terms).insert(self)
  120. +    }
  121. +
  122. +    fn secret_make_product<N: Into<T::Coefficient>>(&mut self,
  123. +                                                    coefficient: N,
  124. +                                                    powers: Vec<Container<Power<T>>>)
  125. +                                                    -> Container<Product<T>> {
  126. +        Product::new(coefficient.into(), powers).insert(self)
  127.      }
  128.  
  129. -    fn make_product(&mut self,
  130. -                    coefficient: &T::Constant,
  131. -                    powers: &Vec<Container<Power<T>>>)
  132. -                    -> Option<Container<Product<T>>> {
  133. -        if coefficient.clone() == T::Constant::zero() || powers.len() == 0 {
  134. +    /// Makes a canonicalized Product from the given components. Because some inputs create
  135. +    /// invalid products, this returns an Option.
  136. +    pub fn make_product<C, N: Into<T::Coefficient>>(&mut self,
  137. +                                                    coefficient: N,
  138. +                                                    powers: C)
  139. +                                                    -> Option<Container<Product<T>>>
  140. +        where C: IntoIterator<Item = Container<Power<T>>>
  141. +    {
  142. +        let coefficient = coefficient.into();
  143. +        if coefficient != 0.into() {
  144. +            let powers = powers.into_iter().collect::<Vec<_>>();
  145. +            if powers.len() != 0 {
  146. +                return Some(self.secret_make_product(coefficient, powers));
  147. +            }
  148. +        }
  149. +        None
  150. +    }
  151. +
  152. +    fn secret_make_power<N: Into<T::Exponent>>(&mut self,
  153. +                                               primitive: Container<Primitive<T>>,
  154. +                                               exponent: N)
  155. +                                               -> Container<Power<T>> {
  156. +        Power::new(exponent.into(), primitive).insert(self)
  157. +    }
  158. +
  159. +    /// Makes a canonicalized Power from the given components. Because some inputs create
  160. +    /// invalid powers, this returns an Option.
  161. +    pub fn make_power<N: Into<T::Exponent>>(&mut self,
  162. +                                            primitive: Container<Primitive<T>>,
  163. +                                            exponent: N)
  164. +                                            -> Option<Container<Power<T>>> {
  165. +        let exponent = exponent.into();
  166. +        if exponent <= 0.into() {
  167.              None
  168.          } else {
  169. -            Some(Product::new(coefficient.clone(), powers.clone()).insert(self))
  170. -        }
  171. -    }
  172. -
  173. -    fn make_power(&mut self,
  174. -                  primitive: &Container<Primitive<T>>,
  175. -                  exponent: &T::Exponent)
  176. -                  -> Option<Container<Power<T>>> {
  177. -        if exponent.clone() <= T::Exponent::zero() {
  178. -            None
  179. -        } else {
  180. -            Some(Power::new(exponent.clone(), primitive.clone()).insert(self))
  181. +            Some(self.secret_make_power(primitive, exponent))
  182.          }
  183.      }
  184.  
  185. -    fn make_system_variable(&mut self, index: &T::SystemVariable) -> Container<Primitive<T>> {
  186. -        Primitive::SystemVariable(index.clone()).insert(self)
  187. +    /// Makes a canonicalized SystemVariable from the given components.
  188. +    pub fn make_system_variable(&mut self, index: T::SystemVariable) -> Container<Primitive<T>> {
  189. +        Primitive::SystemVariable(index).insert(self)
  190.      }
  191.  
  192. -    fn make_parameter(&mut self, name: &T::Parameter) -> Container<Primitive<T>> {
  193. -        Primitive::Parameter(name.clone()).insert(self)
  194. +    /// Makes a canonicalized Parameter from the given components.
  195. +    pub fn make_parameter(&mut self, name: T::Parameter) -> Container<Primitive<T>> {
  196. +        Primitive::Parameter(name).insert(self)
  197.      }
  198.  
  199. -    fn make_input(&mut self, index: &T::Input) -> Container<Primitive<T>> {
  200. -        Primitive::Input(index.clone()).insert(self)
  201. +    /// Makes a canonicalized Input from the given components.
  202. +    pub fn make_input(&mut self, index: T::Input) -> Container<Primitive<T>> {
  203. +        Primitive::Input(index).insert(self)
  204.      }
  205.  
  206. -    fn make_sigmoid(&mut self, minus: bool, sum: &Container<Sum<T>>) -> Container<Primitive<T>> {
  207. -        if sum.minus || (sum.terms.len() == 0 && sum.constant < T::Constant::zero()) {
  208. +    /// Makes a canonicalized Sigmoid from the given components.
  209. +    pub fn make_sigmoid(&mut self, minus: bool, sum: Container<Sum<T>>) -> Container<Primitive<T>> {
  210. +        if sum.minus || (sum.terms.len() == 0 && sum.constant < 0.into()) {
  211.              let sum = sum.negate(self);
  212. -            self.make_sigmoid(!minus, &sum)
  213. +            self.make_sigmoid(!minus, sum)
  214.          } else {
  215. -            Primitive::Sigmoid(minus, sum.clone()).insert(self)
  216. +            Primitive::Sigmoid(minus, sum).insert(self)
  217.          }
  218.      }
  219.  
  220. -    fn negate_products(&mut self,
  221. -                       terms: &Vec<Container<Product<T>>>)
  222. -                       -> Vec<Container<Product<T>>> {
  223. -        terms.iter().map(|p| p.negate(self)).collect::<Vec<_>>()
  224. -    }
  225. -    fn number(&mut self, number: &T::Constant) -> Container<Sum<T>> {
  226. -        self.make_sum(false, number, &vec![])
  227. +    fn negate_products<C>(&mut self, terms: C) -> Vec<Container<Product<T>>>
  228. +        where C: IntoIterator<Item = Container<Product<T>>>
  229. +    {
  230. +        terms.into_iter().map(|p| p.negate(self)).collect()
  231.      }
  232.  
  233. -    fn add<A: Node<T>, B: Node<T>>(&mut self,
  234. -                                   left: &Container<A>,
  235. -                                   right: &Container<B>)
  236. -                                   -> Container<Sum<T>> {
  237. +    /// Makes a constant Sum.
  238. +    pub fn number<N: Into<T::Constant>>(&mut self, number: N) -> Container<Sum<T>> {
  239. +        self.make_sum(false, number, vec![])
  240. +    }
  241. +
  242. +    /// Adds two Nodes together, producing a Sum.
  243. +    pub fn add<A: Node<T>, B: Node<T>>(&mut self,
  244. +                                       left: Container<A>,
  245. +                                       right: Container<B>)
  246. +                                       -> Container<Sum<T>> {
  247.          let left_sum = left.as_sum(self);
  248.          let right_sum = right.as_sum(self);
  249.          let constant = left_sum.at_zero() + right_sum.at_zero();
  250.          if left_sum.terms.len() == 0 {
  251. -            return right_sum.adjust_constant(&constant, self);
  252. +            return right_sum.adjust_constant(constant, self);
  253.          }
  254.          if right_sum.terms.len() == 0 {
  255. -            return left_sum.adjust_constant(&constant, self);
  256. +            return left_sum.adjust_constant(constant, self);
  257.          }
  258.  
  259.          // Begin port of weird logic
  260. @@ -170,8 +213,9 @@
  261.                  Ordering::Equal => {
  262.                      let coefficient = left_term.coefficient.clone() +
  263.                                        right_term.coefficient.clone();
  264. -                    if coefficient != T::Constant::zero() {
  265. -                        terms.push(self.make_product(&coefficient, &left_term.powers).unwrap());
  266. +                    match self.make_product(coefficient, left_term.powers.iter().cloned()) {
  267. +                        Some(product) => terms.push(product),
  268. +                        None => (),
  269.                      }
  270.                      left_index += 1;
  271.                      right_index += 1;
  272. @@ -182,40 +226,41 @@
  273.                  }
  274.              }
  275.              if left_index == left_length {
  276. -                terms.extend_from_slice(&right_terms[..].split_at(right_index).1);
  277. +                terms.extend_from_slice(&right_terms[right_index..]);
  278.              }
  279.              if right_index == right_length {
  280. -                terms.extend_from_slice(&left_terms[..].split_at(left_index).1);
  281. +                terms.extend_from_slice(&left_terms[left_index..]);
  282.              }
  283.          }
  284.  
  285.          terms.shrink_to_fit(); // Why not.
  286.  
  287. -        if terms.len() == 0 || terms[0].coefficient >= T::Constant::zero() {
  288. -            self.make_sum(false, &constant, &terms)
  289. +        if terms.len() == 0 || terms[0].coefficient >= 0.into() {
  290. +            self.make_sum(false, constant, terms)
  291.          } else {
  292. -            let terms = self.negate_products(&terms);
  293. -            self.make_sum(true, &constant, &terms)
  294. +            let terms = self.negate_products(terms);
  295. +            self.make_sum(true, constant, terms)
  296.          }
  297.      }
  298.  
  299. -    fn multiply<A: Node<T>, B: Node<T>>(&mut self,
  300. -                                        left: &Container<A>,
  301. -                                        right: &Container<B>)
  302. -                                        -> Container<Sum<T>> {
  303. +    /// Multiplies two Nodes together, producing a Sum.
  304. +    pub fn multiply<A: Node<T>, B: Node<T>>(&mut self,
  305. +                                            left: Container<A>,
  306. +                                            right: Container<B>)
  307. +                                            -> Container<Sum<T>> {
  308.          let left_sum = left.as_sum(self);
  309.          let right_sum = right.as_sum(self);
  310.          let minus = left_sum.minus != right_sum.minus;
  311.  
  312.          let first_constant = left_sum.constant.clone() * right_sum.constant.clone();
  313. -        let first = self.make_sum(minus, &first_constant, &vec![]);
  314. +        let first = self.make_sum(minus, first_constant, vec![]);
  315.  
  316.          let mut outer_terms = Vec::with_capacity(right_sum.terms.len());
  317. -        if left_sum.constant != T::Constant::zero() {
  318. +        if left_sum.constant != 0.into() {
  319.              for term in &right_sum.terms {
  320. -                let outer_coefficient = left_sum.constant.clone() * term.coefficient.clone();
  321. -                let outer_term = self.make_product(&outer_coefficient, &term.powers).unwrap();
  322. -                if left_sum.constant < T::Constant::zero() {
  323. +                let outer_coefficient = left_sum.constant.clone().into() * term.coefficient.clone();
  324. +                let outer_term = self.secret_make_product(outer_coefficient, term.powers.clone());
  325. +                if left_sum.constant < 0.into() {
  326.                      outer_terms.push(outer_term.negate(self));
  327.                  } else {
  328.                      outer_terms.push(outer_term);
  329. @@ -223,17 +268,16 @@
  330.              }
  331.          }
  332.          outer_terms.shrink_to_fit();
  333. -        let outer = self.make_sum((left_sum.constant < T::Constant::zero()) != minus,
  334. -                                  &T::Constant::zero(),
  335. -                                  &outer_terms);
  336. +        let outer = self.make_sum((left_sum.constant < 0.into()) != minus, 0, outer_terms);
  337.  
  338.          // Should be possible to turn these into functions.
  339.          let mut inner_terms = Vec::with_capacity(left_sum.terms.len());
  340. -        if right_sum.constant != T::Constant::zero() {
  341. +        if right_sum.constant != 0.into() {
  342.              for term in &left_sum.terms {
  343. -                let inner_coefficient = right_sum.constant.clone() * term.coefficient.clone();
  344. -                let inner_term = self.make_product(&inner_coefficient, &term.powers).unwrap();
  345. -                if right_sum.constant < T::Constant::zero() {
  346. +                let inner_coefficient = right_sum.constant.clone().into() *
  347. +                                        term.coefficient.clone();
  348. +                let inner_term = self.secret_make_product(inner_coefficient, term.powers.clone());
  349. +                if right_sum.constant < 0.into() {
  350.                      inner_terms.push(inner_term.negate(self));
  351.                  } else {
  352.                      inner_terms.push(inner_term);
  353. @@ -241,24 +285,22 @@
  354.              }
  355.          }
  356.          inner_terms.shrink_to_fit();
  357. -        let inner = self.make_sum((right_sum.constant < T::Constant::zero()) != minus,
  358. -                                  &T::Constant::zero(),
  359. -                                  &inner_terms);
  360. +        let inner = self.make_sum((right_sum.constant < 0.into()) != minus, 0, inner_terms);
  361.  
  362. -        let mut last = self.number(&T::Constant::zero());
  363. +        let mut last = self.number(0);
  364.          for left_term in &left_sum.terms {
  365.              for right_term in &right_sum.terms {
  366. -                let combined_terms = self._combine_terms(&left_term.powers, &right_term.powers);
  367.                  let last_coefficient = left_term.coefficient.clone() *
  368.                                         right_term.coefficient.clone();
  369. -                let product = self.make_product(&last_coefficient, &combined_terms).unwrap();
  370. -                last = self.add(&last, &product);
  371. +                let combined_terms = self._combine_terms(&left_term.powers, &right_term.powers);
  372. +                let product = self.secret_make_product(last_coefficient, combined_terms);
  373. +                last = self.add(last, product);
  374.              }
  375.          }
  376.  
  377. -        last = self.add(&last, &inner);
  378. -        last = self.add(&last, &outer);
  379. -        self.add(&last, &first)
  380. +        last = self.add(last, inner);
  381. +        last = self.add(last, outer);
  382. +        self.add(last, first)
  383.      }
  384.  
  385.      fn _combine_terms(&mut self,
  386. @@ -275,14 +317,14 @@
  387.          while left_index < left_length && right_index < right_length {
  388.              let ref left_power = left[left_index];
  389.              let ref right_power = right[right_length];
  390. -            match left_power.fuzzy_cmp(&right_power) {
  391. +            match left_power.fuzzy_cmp(right_power.as_ref()) {
  392.                  Ordering::Less => {
  393.                      powers.push(left_power.clone());
  394.                      left_index += 1;
  395.                  }
  396.                  Ordering::Equal => {
  397.                      let exponent = left_power.exponent.clone() + right_power.exponent.clone();
  398. -                    powers.push(self.make_power(&left_power.primitive, &exponent).unwrap());
  399. +                    powers.push(self.secret_make_power(left_power.primitive.clone(), exponent));
  400.                      left_index += 1;
  401.                      right_index += 1;
  402.                  }
  403. @@ -292,10 +334,10 @@
  404.                  }
  405.              }
  406.              if left_index == left_length {
  407. -                powers.extend_from_slice(&right[..].split_at(right_index).1);
  408. +                powers.extend_from_slice(&right[right_index..]);
  409.              }
  410.              if right_index == right_length {
  411. -                powers.extend_from_slice(&left[..].split_at(left_index).1);
  412. +                powers.extend_from_slice(&left[left_index..]);
  413.              }
  414.          }
  415.  
  416. @@ -306,55 +348,37 @@
  417.      }
  418.  }
  419.  
  420. -trait Summable<T: NodeData> {
  421. -    fn as_sum(&self, env: &mut Environment<T>) -> Container<Sum<T>>;
  422. -}
  423. -
  424. -trait Productable<T: NodeData> {
  425. -    fn as_product(&self, env: &mut Environment<T>) -> Container<Product<T>>;
  426. -}
  427. -
  428. -fn _as_sum<T: NodeData>(productable: &Productable<T>,
  429. -                        env: &mut Environment<T>)
  430. -                        -> Container<Sum<T>> {
  431. -    let product = productable.as_product(env);
  432. -    if product.coefficient < T::Constant::zero() {
  433. -        let product = product.negate(env);
  434. -        env.make_sum(true, &T::Constant::zero(), &vec![product])
  435. -    } else {
  436. -        env.make_sum(false, &T::Constant::zero(), &vec![product])
  437. -    }
  438. -}
  439. -
  440. -impl<T: NodeData> Summable<T> for Product<T> {
  441. -    fn as_sum(&self, env: &mut Environment<T>) -> Container<Sum<T>> {
  442. -        _as_sum(self, env)
  443. -    }
  444. -}
  445. -
  446. -impl<T: NodeData> Summable<T> for Power<T> {
  447. -    fn as_sum(&self, env: &mut Environment<T>) -> Container<Sum<T>> {
  448. -        _as_sum(self, env)
  449. -    }
  450. -}
  451. -
  452. -impl<T: NodeData> Summable<T> for Primitive<T> {
  453. -    fn as_sum(&self, env: &mut Environment<T>) -> Container<Sum<T>> {
  454. -        _as_sum(self, env)
  455. -    }
  456. -}
  457. -
  458. -trait Node<T: NodeData>: Summable<T> {
  459. +pub trait Node<T: NodeData>: Summable<T> + FancyClone<T> {
  460.      fn partial_derivative(&self,
  461.                            variable: &T::Parameter,
  462.                            env: &mut Environment<T>)
  463. -                          -> Container<Sum<T>>;
  464. -
  465. -    fn fancy_clone(&self, env: &mut Environment<T>) -> Self;
  466. +                          -> Container<Sum<T>>
  467. +        where Self: Sized + Clone + Eq + Hash
  468. +    {
  469. +        if !self.get_derivative_cache(env).contains_key(&(self.clone(), variable.clone())) {
  470. +            let key = (self.fancy_clone(env), variable.clone());
  471. +            let result = self.derivative_impl(variable, env);
  472. +            self.get_derivative_cache(env).insert(key, result);
  473. +        }
  474. +        match self.get_derivative_cache(env).get(&(self.clone(), variable.clone())) {
  475. +            Some(derivative) => derivative.clone(),
  476. +            None => unreachable!(),
  477. +        }
  478. +    }
  479.  
  480.      fn get_storage<'a>(&self, env: &'a mut Environment<T>) -> &'a mut SelfMap<Self>
  481.          where Self: Sized;
  482.  
  483. +    fn get_derivative_cache<'a>(&self,
  484. +                                env: &'a mut Environment<T>)
  485. +                                -> &'a mut HashMap<(Self, T::Parameter), Container<Sum<T>>>
  486. +        where Self: Sized;
  487. +
  488. +    fn derivative_impl(&self,
  489. +                       variable: &T::Parameter,
  490. +                       env: &mut Environment<T>)
  491. +                       -> Container<Sum<T>>;
  492. +
  493.      fn insert(&self, env: &mut Environment<T>) -> Container<Self>
  494.          where Self: Sized + Clone + Eq + Hash
  495.      {
  496. @@ -364,173 +388,142 @@
  497.              self.get_storage(env).insert(key, container(value));
  498.          }
  499.          match self.get_storage(env).get(self) {
  500. -                Some(node) => node,
  501. -                None => unreachable!(),
  502. -            }
  503. -            .clone()
  504. -    }
  505. -}
  506. -
  507. -impl<T: NodeData> Summable<T> for Sum<T> {
  508. -    fn as_sum(&self, env: &mut Environment<T>) -> Container<Sum<T>> {
  509. -        self.insert(env)
  510. +            Some(node) => node.clone(),
  511. +            None => unreachable!(),
  512. +        }
  513.      }
  514.  }
  515.  
  516.  impl<T: NodeData> Node<T> for Sum<T> {
  517. -    fn partial_derivative(&self,
  518. -                          variable: &T::Parameter,
  519. -                          env: &mut Environment<T>)
  520. -                          -> Container<Sum<T>> {
  521. -        if !env.sum_derivative.contains_key(&(self.clone(), variable.clone())) {
  522. -            let key = (self.fancy_clone(env), variable.clone());
  523. -            let mut result = env.number(&T::Constant::zero());
  524. -            for term in &self.terms {
  525. -                let partial = term.partial_derivative(variable, env);
  526. -                result = env.add(&result, &partial);
  527. -            }
  528. -            if self.minus {
  529. -                result = result.negate(env);
  530. -            }
  531. -            env.sum_derivative.insert(key, result);
  532. +    fn derivative_impl(&self,
  533. +                       variable: &T::Parameter,
  534. +                       env: &mut Environment<T>)
  535. +                       -> Container<Sum<T>> {
  536. +        let mut result = env.number(0);
  537. +        for term in &self.terms {
  538. +            let partial = term.partial_derivative(variable, env);
  539. +            result = env.add(result, partial);
  540.          }
  541. -        match env.sum_derivative.get(&(self.clone(), variable.clone())) {
  542. -            Some(derivative) => derivative.clone(),
  543. -            None => unreachable!(),
  544. +        if self.minus {
  545. +            result = result.negate(env);
  546.          }
  547. -    }
  548. -
  549. -    fn fancy_clone(&self, env: &mut Environment<T>) -> Sum<T> {
  550. -        Sum {
  551. -            pre_hash: self.pre_hash.clone(),
  552. -            minus: self.minus.clone(),
  553. -            constant: self.constant.clone(),
  554. -            terms: self.terms
  555. -                .iter()
  556. -                .map(|p| p.insert(env))
  557. -                .collect::<Vec<_>>(),
  558. -        }
  559. +        result
  560.      }
  561.  
  562.      fn get_storage<'a>(&self, env: &'a mut Environment<T>) -> &'a mut SelfMap<Sum<T>> {
  563.          &mut env.sum
  564.      }
  565. -}
  566.  
  567. -impl<T: NodeData> Productable<T> for Product<T> {
  568. -    fn as_product(&self, env: &mut Environment<T>) -> Container<Product<T>> {
  569. -        self.insert(env)
  570. +    fn get_derivative_cache<'a>(&self,
  571. +                                env: &'a mut Environment<T>)
  572. +                                -> &'a mut HashMap<(Sum<T>, T::Parameter), Container<Sum<T>>>
  573. +        where Self: Sized
  574. +    {
  575. +        &mut env.sum_derivative
  576.      }
  577.  }
  578.  
  579.  impl<T: NodeData> Node<T> for Product<T> {
  580. -    fn partial_derivative(&self,
  581. -                          variable: &T::Parameter,
  582. -                          env: &mut Environment<T>)
  583. -                          -> Container<Sum<T>> {
  584. -        if !env.product_derivative.contains_key(&(self.clone(), variable.clone())) {
  585. -            let key = (self.fancy_clone(env), variable.clone());
  586. -            let result = if self.powers.len() == 1 {
  587. -                self.powers[0].partial_derivative(variable, env)
  588. -            } else {
  589. -                let mut result = env.number(&T::Constant::zero());
  590. -                unimplemented!()
  591. -                /*for (index, focus) in self.powers.iter().enumerate() {
  592. -                    let remainder = env.make_product(T::Constant::one(), )
  593. -                }*/
  594. -            };
  595. -            env.product_derivative.insert(key, result);
  596. -        }
  597. -        match env.product_derivative.get(&(self.clone(), variable.clone())) {
  598. -            Some(derivative) => derivative.clone(),
  599. -            None => unreachable!(),
  600. -        }
  601. -    }
  602. -
  603. -    fn fancy_clone(&self, env: &mut Environment<T>) -> Product<T> {
  604. -        Product {
  605. -            pre_hash: self.pre_hash.clone(),
  606. -            coefficient: self.coefficient.clone(),
  607. -            powers: self.powers
  608. -                .iter()
  609. -                .map(|p| p.insert(env))
  610. -                .collect::<Vec<_>>(),
  611. -        }
  612. +    fn derivative_impl(&self,
  613. +                       variable: &T::Parameter,
  614. +                       env: &mut Environment<T>)
  615. +                       -> Container<Sum<T>> {
  616. +        let coefficient = env.number(self.coefficient.clone());
  617. +        let result = if self.powers.len() == 1 {
  618. +            self.powers[0].partial_derivative(variable, env)
  619. +        } else {
  620. +            let mut result = env.number(0);
  621. +            for (index, focus) in self.powers.iter().enumerate() {
  622. +                let remainder = env.secret_make_product(1,
  623. +                                                        self.powers[..index]
  624. +                                                            .iter()
  625. +                                                            .chain(self.powers[index + 1..]
  626. +                                                                       .iter())
  627. +                                                            .cloned()
  628. +                                                            .collect());
  629. +                let partial_derivative = focus.partial_derivative(variable, env);
  630. +                let product = env.multiply(remainder, partial_derivative);
  631. +                result = env.add(result, product);
  632. +            }
  633. +            result
  634. +        };
  635. +        env.multiply(result, coefficient)
  636.      }
  637.  
  638.      fn get_storage<'a>(&self, env: &'a mut Environment<T>) -> &'a mut SelfMap<Product<T>> {
  639.          &mut env.product
  640.      }
  641. -}
  642.  
  643. -impl<T: NodeData> Productable<T> for Power<T> {
  644. -    fn as_product(&self, env: &mut Environment<T>) -> Container<Product<T>> {
  645. -        let power = self.insert(env);
  646. -        env.make_product(&T::Constant::one(), &vec![power]).unwrap()
  647. +    fn get_derivative_cache<'a>(&self,
  648. +                                env: &'a mut Environment<T>)
  649. +                                -> &'a mut HashMap<(Product<T>, T::Parameter), Container<Sum<T>>>
  650. +        where Self: Sized
  651. +    {
  652. +        &mut env.product_derivative
  653.      }
  654.  }
  655.  
  656.  impl<T: NodeData> Node<T> for Power<T> {
  657. -    fn partial_derivative(&self,
  658. -                          variable: &T::Parameter,
  659. -                          env: &mut Environment<T>)
  660. -                          -> Container<Sum<T>> {
  661. -        unimplemented!();
  662. -    }
  663. -
  664. -    fn fancy_clone(&self, env: &mut Environment<T>) -> Power<T> {
  665. -        Power {
  666. -            exponent: self.exponent.clone(),
  667. -            primitive: self.primitive.insert(env),
  668. +    fn derivative_impl(&self,
  669. +                       variable: &T::Parameter,
  670. +                       env: &mut Environment<T>)
  671. +                       -> Container<Sum<T>> {
  672. +        let derivative = self.primitive.partial_derivative(variable, env);
  673. +        if self.exponent == 1.into() {
  674. +            derivative
  675. +        } else {
  676. +            let coefficient = env.number(self.exponent.clone());
  677. +            let remainder = env.secret_make_power(self.primitive.clone(),
  678. +                                                  self.exponent.clone() - 1.into());
  679. +            let product = env.multiply(derivative, remainder);
  680. +            env.multiply(coefficient, product)
  681.          }
  682.      }
  683.  
  684.      fn get_storage<'a>(&self, env: &'a mut Environment<T>) -> &'a mut SelfMap<Power<T>> {
  685.          &mut env.power
  686.      }
  687. -}
  688.  
  689. -impl<T: NodeData> Primitive<T> {
  690. -    fn as_power(&self, env: &mut Environment<T>) -> Container<Power<T>> {
  691. -        let primitive = self.insert(env);
  692. -        env.make_power(&primitive, &T::Exponent::one()).unwrap()
  693. -    }
  694. -}
  695. -
  696. -impl<T: NodeData> Productable<T> for Primitive<T> {
  697. -    fn as_product(&self, env: &mut Environment<T>) -> Container<Product<T>> {
  698. -        match self {
  699. -            &Primitive::Sigmoid(true, ref sum) => {
  700. -                let sigmoid = env.make_sigmoid(false, sum);
  701. -                let power = sigmoid.as_power(env);
  702. -                env.make_product(&-T::Constant::one(), &vec![power]).unwrap()
  703. -            }
  704. -            _ => {
  705. -                let power = self.as_power(env);
  706. -                power.as_product(env)
  707. -            }
  708. -        }
  709. +    fn get_derivative_cache<'a>(&self,
  710. +                                env: &'a mut Environment<T>)
  711. +                                -> &'a mut HashMap<(Power<T>, T::Parameter), Container<Sum<T>>>
  712. +        where Self: Sized
  713. +    {
  714. +        &mut env.power_derivative
  715.      }
  716.  }
  717.  
  718.  impl<T: NodeData> Node<T> for Primitive<T> {
  719. -    fn partial_derivative(&self,
  720. -                          variable: &T::Parameter,
  721. -                          env: &mut Environment<T>)
  722. -                          -> Container<Sum<T>> {
  723. -        unimplemented!();
  724. -    }
  725. -
  726. -    fn fancy_clone(&self, env: &mut Environment<T>) -> Primitive<T> {
  727. +    fn derivative_impl(&self,
  728. +                       variable: &T::Parameter,
  729. +                       env: &mut Environment<T>)
  730. +                       -> Container<Sum<T>> {
  731.          match self {
  732. -            &Primitive::Sigmoid(minus, ref a) => Primitive::Sigmoid(minus, a.insert(env)),
  733. -            _ => self.clone(),
  734. +            &Primitive::Sigmoid(minus, ref sum) => {
  735. +                let derivative = sum.partial_derivative(variable, env);
  736. +                let self_in_env = self.insert(env);
  737. +                let self_squared = env.secret_make_power(self_in_env, 2);
  738. +                let minus_self_squared = env.secret_make_product(-1, vec![self_squared]);
  739. +                let sum = env.make_sum(minus, 1, vec![minus_self_squared]);
  740. +                env.multiply(derivative, sum)
  741. +            }
  742. +            &Primitive::Parameter(ref p) => env.number(if p == variable { 1 } else { 0 }),
  743. +            _ => env.number(0),
  744.          }
  745.      }
  746.  
  747.      fn get_storage<'a>(&self, env: &'a mut Environment<T>) -> &'a mut SelfMap<Primitive<T>> {
  748.          &mut env.primitive
  749.      }
  750. +
  751. +    fn get_derivative_cache<'a>
  752. +        (&self,
  753. +         env: &'a mut Environment<T>)
  754. +         -> &'a mut HashMap<(Primitive<T>, T::Parameter), Container<Sum<T>>>
  755. +        where Self: Sized
  756. +    {
  757. +        &mut env.primitive_derivative
  758. +    }
  759.  }
  760.  
  761.  #[cfg(test)]
  762. diff --git a/rust/adagio/src/node_conversions.rs b/rust/adagio/src/node_conversions.rs
  763. new file mode 100644
  764. --- /dev/null
  765. +++ b/rust/adagio/src/node_conversions.rs
  766. @@ -0,0 +1,83 @@
  767. +use {Environment, Node};
  768. +
  769. +use node_defs::{Container, NodeData, Sum, Product, Power, Primitive};
  770. +
  771. +pub trait Summable<T: NodeData> {
  772. +    fn as_sum(&self, env: &mut Environment<T>) -> Container<Sum<T>>;
  773. +}
  774. +
  775. +trait Productable<T: NodeData> {
  776. +    fn as_product(&self, env: &mut Environment<T>) -> Container<Product<T>>;
  777. +}
  778. +
  779. +fn _as_sum<T: NodeData, U: Productable<T>>(productable: &U,
  780. +                                           env: &mut Environment<T>)
  781. +                                           -> Container<Sum<T>> {
  782. +    let product = productable.as_product(env);
  783. +    if product.coefficient < 0.into() {
  784. +        let product = product.as_ref().negate(env);
  785. +        env.make_sum(true, 0, vec![product])
  786. +    } else {
  787. +        env.make_sum(false, 0, vec![product])
  788. +    }
  789. +}
  790. +
  791. +impl<T: NodeData> Productable<T> for Product<T> {
  792. +    fn as_product(&self, env: &mut Environment<T>) -> Container<Product<T>> {
  793. +        self.insert(env)
  794. +    }
  795. +}
  796. +
  797. +impl<T: NodeData> Productable<T> for Power<T> {
  798. +    fn as_product(&self, env: &mut Environment<T>) -> Container<Product<T>> {
  799. +        let power = self.insert(env);
  800. +        env.secret_make_product(1, vec![power])
  801. +    }
  802. +}
  803. +
  804. +impl<T: NodeData> Primitive<T> {
  805. +    fn as_power(&self, env: &mut Environment<T>) -> Container<Power<T>> {
  806. +        let primitive = self.insert(env);
  807. +        env.secret_make_power(primitive, 1)
  808. +    }
  809. +}
  810. +
  811. +impl<T: NodeData> Productable<T> for Primitive<T> {
  812. +    fn as_product(&self, env: &mut Environment<T>) -> Container<Product<T>> {
  813. +        match self {
  814. +            &Primitive::Sigmoid(true, ref sum) => {
  815. +                let sigmoid = env.make_sigmoid(false, sum.clone());
  816. +                let power = sigmoid.as_power(env);
  817. +                env.secret_make_product(-1, vec![power])
  818. +            }
  819. +            _ => {
  820. +                let power = self.as_power(env);
  821. +                power.as_product(env)
  822. +            }
  823. +        }
  824. +    }
  825. +}
  826. +
  827. +impl<T: NodeData> Summable<T> for Sum<T> {
  828. +    fn as_sum(&self, env: &mut Environment<T>) -> Container<Sum<T>> {
  829. +        self.insert(env)
  830. +    }
  831. +}
  832. +
  833. +impl<T: NodeData> Summable<T> for Product<T> {
  834. +    fn as_sum(&self, env: &mut Environment<T>) -> Container<Sum<T>> {
  835. +        _as_sum(self, env)
  836. +    }
  837. +}
  838. +
  839. +impl<T: NodeData> Summable<T> for Power<T> {
  840. +    fn as_sum(&self, env: &mut Environment<T>) -> Container<Sum<T>> {
  841. +        _as_sum(self, env)
  842. +    }
  843. +}
  844. +
  845. +impl<T: NodeData> Summable<T> for Primitive<T> {
  846. +    fn as_sum(&self, env: &mut Environment<T>) -> Container<Sum<T>> {
  847. +        _as_sum(self, env)
  848. +    }
  849. +}
  850. diff --git a/rust/adagio/src/node_defs.rs b/rust/adagio/src/node_defs.rs
  851. --- a/rust/adagio/src/node_defs.rs
  852. +++ b/rust/adagio/src/node_defs.rs
  853. @@ -5,6 +5,8 @@
  854.  
  855.  use num::{Integer, Signed};
  856.  
  857. +use {Environment, Node};
  858. +
  859.  pub type Container<T> = Arc<T>;
  860.  
  861.  pub fn container<T>(data: T) -> Container<T> {
  862. @@ -12,23 +14,28 @@
  863.  }
  864.  
  865.  pub trait NodeData {
  866. -    type Constant: Hash + Ord + Clone + Signed;
  867. -    type Exponent: Hash + Clone + Integer;
  868. +    type Constant: Hash + Ord + Clone + Signed + Into<Self::Coefficient> + From<i32>;
  869. +    type Coefficient: Hash + Ord + Clone + Signed + Into<Self::Constant> + From<i32>;
  870. +    type Exponent: Hash + Clone + Integer + Into<Self::Constant> + From<i32>;
  871.      type Input: Hash + Ord + Clone;
  872.      type SystemVariable: Hash + Ord + Clone;
  873.      type Parameter: Hash + Ord + Clone;
  874.  }
  875.  
  876. +pub trait FancyClone<T: NodeData> {
  877. +    fn fancy_clone(&self, env: &mut Environment<T>) -> Self;
  878. +}
  879. +
  880.  pub struct Sum<T: NodeData> {
  881. -    pub pre_hash: u64, //I would like these fields to be private if possible.
  882. +    pre_hash: u64, //I would like these fields to be private if possible.
  883.      pub minus: bool,
  884.      pub constant: T::Constant,
  885.      pub terms: Vec<Container<Product<T>>>,
  886.  }
  887.  
  888.  pub struct Product<T: NodeData> {
  889. -    pub pre_hash: u64, //I would like these fields to be private if possible.
  890. -    pub coefficient: T::Constant,
  891. +    pre_hash: u64, //I would like these fields to be private if possible.
  892. +    pub coefficient: T::Coefficient,
  893.      pub powers: Vec<Container<Power<T>>>,
  894.  }
  895.  
  896. @@ -145,6 +152,51 @@
  897.      }
  898.  }
  899.  
  900. +impl<T: NodeData> FancyClone<T> for Sum<T> {
  901. +    fn fancy_clone(&self, env: &mut Environment<T>) -> Sum<T> {
  902. +        Sum {
  903. +            pre_hash: self.pre_hash.clone(),
  904. +            minus: self.minus.clone(),
  905. +            constant: self.constant.clone(),
  906. +            terms: self.terms
  907. +                .iter()
  908. +                .map(|p| p.insert(env))
  909. +                .collect(),
  910. +        }
  911. +    }
  912. +}
  913. +
  914. +impl<T: NodeData> FancyClone<T> for Product<T> {
  915. +    fn fancy_clone(&self, env: &mut Environment<T>) -> Product<T> {
  916. +        Product {
  917. +            pre_hash: self.pre_hash.clone(),
  918. +            coefficient: self.coefficient.clone(),
  919. +            powers: self.powers
  920. +                .iter()
  921. +                .map(|p| p.insert(env))
  922. +                .collect(),
  923. +        }
  924. +    }
  925. +}
  926. +
  927. +impl<T: NodeData> FancyClone<T> for Power<T> {
  928. +    fn fancy_clone(&self, env: &mut Environment<T>) -> Power<T> {
  929. +        Power {
  930. +            exponent: self.exponent.clone(),
  931. +            primitive: self.primitive.insert(env),
  932. +        }
  933. +    }
  934. +}
  935. +
  936. +impl<T: NodeData> FancyClone<T> for Primitive<T> {
  937. +    fn fancy_clone(&self, env: &mut Environment<T>) -> Primitive<T> {
  938. +        match self {
  939. +            &Primitive::Sigmoid(minus, ref a) => Primitive::Sigmoid(minus, a.insert(env)),
  940. +            _ => self.clone(),
  941. +        }
  942. +    }
  943. +}
  944. +
  945.  impl<T: NodeData> Clone for Sum<T> {
  946.      fn clone(&self) -> Sum<T> {
  947.          Sum {
  948. @@ -202,7 +254,7 @@
  949.  }
  950.  
  951.  impl<T: NodeData> Product<T> {
  952. -    pub fn new(coefficient: T::Constant, powers: Vec<Container<Power<T>>>) -> Product<T> {
  953. +    pub fn new(coefficient: T::Coefficient, powers: Vec<Container<Power<T>>>) -> Product<T> {
  954.          let mut s = DefaultHasher::new();
  955.          coefficient.hash(&mut s);
  956.          powers.hash(&mut s);
  957. diff --git a/rust/adagio/src/node_ordering.rs b/rust/adagio/src/node_ordering.rs
  958. --- a/rust/adagio/src/node_ordering.rs
  959. +++ b/rust/adagio/src/node_ordering.rs
  960. @@ -1,6 +1,6 @@
  961.  use std::cmp::Ordering;
  962.  
  963. -use node_defs::{Container, NodeData, Sum, Product, Power, Primitive};
  964. +use node_defs::{NodeData, Sum, Product, Power, Primitive};
  965.  
  966.  #[derive(Eq, Ord, PartialEq, PartialOrd)]
  967.  enum PrimitiveType {
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement