Advertisement
mwchase

LC:NN upload 5a

Apr 2nd, 2017
111
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Diff 8.40 KB | None | 0 0
  1. diff --git a/rust/adagio/src/eval.rs b/rust/adagio/src/eval.rs
  2. --- a/rust/adagio/src/eval.rs
  3. +++ b/rust/adagio/src/eval.rs
  4. @@ -0,0 +1,28 @@
  5. +use std::marker::PhantomData;
  6. +
  7. +use void::Void;
  8. +
  9. +trait Input<I, T> {
  10. +    fn get(&self, index: I) -> T where T: Clone;
  11. +}
  12. +
  13. +trait Parameter<I, T> {
  14. +    fn get(&self, index: I) -> T where T: Clone;
  15. +}
  16. +
  17. +trait Constant<I, T> {
  18. +    fn get(&self, index: I) -> T where T: Clone;
  19. +}
  20. +
  21. +struct InputVec<T> (Vec<T>);
  22. +struct ParamVec<T> (Vec<T>);
  23. +struct UninhabitableConstant<T> (Void, PhantomData<T>);
  24. +
  25. +trait Eval<T> {
  26. +    type InputIndex;
  27. +    type Input: Input<Self::InputIndex, T>;
  28. +    type ParameterIndex;
  29. +    type Parameter: Parameter<Self::ParameterIndex, T>;
  30. +    type ConstantIndex;
  31. +    type Constant: Constant<Self::ConstantIndex, T>;
  32. +}
  33. diff --git a/rust/adagio/src/lib.rs b/rust/adagio/src/lib.rs
  34. --- a/rust/adagio/src/lib.rs
  35. +++ b/rust/adagio/src/lib.rs
  36. @@ -1,9 +1,11 @@
  37.  extern crate num;
  38. +extern crate void;
  39.  
  40.  use std::cmp::Ordering;
  41.  use std::collections::HashMap;
  42.  use std::hash::Hash;
  43.  
  44. +mod eval;
  45.  mod node_conversions;
  46.  mod node_defs;
  47.  mod node_ordering;
  48. @@ -68,10 +70,7 @@
  49.          if self.minus {
  50.              env.negate_products(self.terms.iter().cloned())
  51.          } else {
  52. -            self.terms
  53. -                .iter()
  54. -                .map(|p| p.insert(env))
  55. -                .collect()
  56. +            self.terms.iter().map(|p| p.insert(env)).collect()
  57.          }
  58.      }
  59.  }
  60. @@ -141,9 +140,9 @@
  61.          }
  62.      }
  63.  
  64. -    /// Makes a canonicalized SystemVariable from the given components.
  65. -    pub fn make_system_variable(&mut self, index: T::SystemVariable) -> Container<Primitive<T>> {
  66. -        Primitive::SystemVariable(index).insert(self)
  67. +    /// Makes a canonicalized Real from the given components.
  68. +    pub fn make_real(&mut self, index: T::Real) -> Container<Primitive<T>> {
  69. +        Primitive::Real(index).insert(self)
  70.      }
  71.  
  72.      /// Makes a canonicalized Parameter from the given components.
  73. @@ -355,15 +354,13 @@
  74.                            -> Container<Sum<T>>
  75.          where Self: Sized + Clone + Eq + Hash
  76.      {
  77. -        if !self.get_derivative_cache(env).contains_key(&(self.clone(), variable.clone())) {
  78. +        if !self.get_derivative_cache(env)
  79. +                .contains_key(&(self.clone(), variable.clone())) {
  80.              let key = (self.fancy_clone(env), variable.clone());
  81.              let result = self.derivative_impl(variable, env);
  82.              self.get_derivative_cache(env).insert(key, result);
  83.          }
  84. -        match self.get_derivative_cache(env).get(&(self.clone(), variable.clone())) {
  85. -            Some(derivative) => derivative.clone(),
  86. -            None => unreachable!(),
  87. -        }
  88. +        self.get_derivative_cache(env)[&(self.clone(), variable.clone())].clone()
  89.      }
  90.  
  91.      fn get_storage<'a>(&self, env: &'a mut Environment<T>) -> &'a mut SelfMap<Self>
  92. @@ -387,10 +384,7 @@
  93.              let value = key.clone();
  94.              self.get_storage(env).insert(key, container(value));
  95.          }
  96. -        match self.get_storage(env).get(self) {
  97. -            Some(node) => node.clone(),
  98. -            None => unreachable!(),
  99. -        }
  100. +        self.get_storage(env)[self].clone()
  101.      }
  102.  }
  103.  
  104. diff --git a/rust/adagio/src/node_defs.rs b/rust/adagio/src/node_defs.rs
  105. --- a/rust/adagio/src/node_defs.rs
  106. +++ b/rust/adagio/src/node_defs.rs
  107. @@ -18,7 +18,7 @@
  108.      type Coefficient: Hash + Ord + Clone + Signed + Into<Self::Constant> + From<i32>;
  109.      type Exponent: Hash + Clone + Integer + Into<Self::Constant> + From<i32>;
  110.      type Input: Hash + Ord + Clone;
  111. -    type SystemVariable: Hash + Ord + Clone;
  112. +    type Real: Hash + Ord + Clone;
  113.      type Parameter: Hash + Ord + Clone;
  114.  }
  115.  
  116. @@ -27,14 +27,14 @@
  117.  }
  118.  
  119.  pub struct Sum<T: NodeData> {
  120. -    pre_hash: u64, //I would like these fields to be private if possible.
  121. +    pre_hash: u64,
  122.      pub minus: bool,
  123.      pub constant: T::Constant,
  124.      pub terms: Vec<Container<Product<T>>>,
  125.  }
  126.  
  127.  pub struct Product<T: NodeData> {
  128. -    pre_hash: u64, //I would like these fields to be private if possible.
  129. +    pre_hash: u64,
  130.      pub coefficient: T::Coefficient,
  131.      pub powers: Vec<Container<Power<T>>>,
  132.  }
  133. @@ -46,7 +46,7 @@
  134.  
  135.  pub enum Primitive<T: NodeData> {
  136.      Input(T::Input),
  137. -    SystemVariable(T::SystemVariable),
  138. +    Real(T::Real),
  139.      Parameter(T::Parameter),
  140.      Sigmoid(bool, Container<Sum<T>>),
  141.  }
  142. @@ -94,7 +94,7 @@
  143.      fn eq(&self, other: &Primitive<T>) -> bool {
  144.          match (self, other) {
  145.              (&Primitive::Input(ref a), &Primitive::Input(ref b)) => a == b,
  146. -            (&Primitive::SystemVariable(ref a), &Primitive::SystemVariable(ref b)) => a == b,
  147. +            (&Primitive::Real(ref a), &Primitive::Real(ref b)) => a == b,
  148.              (&Primitive::Parameter(ref a), &Primitive::Parameter(ref b)) => a == b,
  149.              (&Primitive::Sigmoid(minus_a, ref a), &Primitive::Sigmoid(minus_b, ref b)) => {
  150.                  minus_a == minus_b && a == b
  151. @@ -135,8 +135,8 @@
  152.                  "Input".hash(state);
  153.                  a.hash(state)
  154.              }
  155. -            &Primitive::SystemVariable(ref a) => {
  156. -                "SystemVariable".hash(state);
  157. +            &Primitive::Real(ref a) => {
  158. +                "Real".hash(state);
  159.                  a.hash(state)
  160.              }
  161.              &Primitive::Parameter(ref a) => {
  162. @@ -158,10 +158,7 @@
  163.              pre_hash: self.pre_hash.clone(),
  164.              minus: self.minus.clone(),
  165.              constant: self.constant.clone(),
  166. -            terms: self.terms
  167. -                .iter()
  168. -                .map(|p| p.insert(env))
  169. -                .collect(),
  170. +            terms: self.terms.iter().map(|p| p.insert(env)).collect(),
  171.          }
  172.      }
  173.  }
  174. @@ -171,10 +168,7 @@
  175.          Product {
  176.              pre_hash: self.pre_hash.clone(),
  177.              coefficient: self.coefficient.clone(),
  178. -            powers: self.powers
  179. -                .iter()
  180. -                .map(|p| p.insert(env))
  181. -                .collect(),
  182. +            powers: self.powers.iter().map(|p| p.insert(env)).collect(),
  183.          }
  184.      }
  185.  }
  186. @@ -231,7 +225,7 @@
  187.      fn clone(&self) -> Primitive<T> {
  188.          match self {
  189.              &Primitive::Input(ref a) => Primitive::Input(a.clone()),
  190. -            &Primitive::SystemVariable(ref a) => Primitive::SystemVariable(a.clone()),
  191. +            &Primitive::Real(ref a) => Primitive::Real(a.clone()),
  192.              &Primitive::Parameter(ref a) => Primitive::Parameter(a.clone()),
  193.              &Primitive::Sigmoid(minus, ref a) => Primitive::Sigmoid(minus, a.clone()),
  194.          }
  195. diff --git a/rust/adagio/src/node_ordering.rs b/rust/adagio/src/node_ordering.rs
  196. --- a/rust/adagio/src/node_ordering.rs
  197. +++ b/rust/adagio/src/node_ordering.rs
  198. @@ -5,7 +5,7 @@
  199.  #[derive(Eq, Ord, PartialEq, PartialOrd)]
  200.  enum PrimitiveType {
  201.      Input,
  202. -    SystemVariable,
  203. +    Real,
  204.      Parameter,
  205.      Sigmoid,
  206.  }
  207. @@ -14,7 +14,7 @@
  208.      fn simplify(&self) -> PrimitiveType {
  209.          match self {
  210.              &Primitive::Input(_) => PrimitiveType::Input,
  211. -            &Primitive::SystemVariable(_) => PrimitiveType::SystemVariable,
  212. +            &Primitive::Real(_) => PrimitiveType::Real,
  213.              &Primitive::Parameter(_) => PrimitiveType::Parameter,
  214.              &Primitive::Sigmoid(_, _) => PrimitiveType::Sigmoid,
  215.          }
  216. @@ -27,10 +27,7 @@
  217.          match cmp_result {
  218.              Ordering::Equal => {
  219.                  for (self_term, other_term) in
  220. -                    self.terms
  221. -                        .clone()
  222. -                        .into_iter()
  223. -                        .zip(other.terms.clone()) {
  224. +                    self.terms.clone().into_iter().zip(other.terms.clone()) {
  225.                      let cmp_result = self_term.cmp(&other_term);
  226.                      match cmp_result {
  227.                          Ordering::Equal => (),
  228. @@ -94,7 +91,7 @@
  229.      fn cmp(&self, other: &Primitive<T>) -> Ordering {
  230.          match (self, other) {
  231.              (&Primitive::Input(ref a), &Primitive::Input(ref b)) => a.cmp(b),
  232. -            (&Primitive::SystemVariable(ref a), &Primitive::SystemVariable(ref b)) => a.cmp(b),
  233. +            (&Primitive::Real(ref a), &Primitive::Real(ref b)) => a.cmp(b),
  234.              (&Primitive::Parameter(ref a), &Primitive::Parameter(ref b)) => a.cmp(b),
  235.              (&Primitive::Sigmoid(minus_a, ref a), &Primitive::Sigmoid(minus_b, ref b)) => {
  236.                  let cmp_result = a.cmp(b);
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement