Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- diff --git a/rust/adagio/src/eval.rs b/rust/adagio/src/eval.rs
- --- a/rust/adagio/src/eval.rs
- +++ b/rust/adagio/src/eval.rs
- @@ -0,0 +1,28 @@
- +use std::marker::PhantomData;
- +
- +use void::Void;
- +
- +trait Input<I, T> {
- + fn get(&self, index: I) -> T where T: Clone;
- +}
- +
- +trait Parameter<I, T> {
- + fn get(&self, index: I) -> T where T: Clone;
- +}
- +
- +trait Constant<I, T> {
- + fn get(&self, index: I) -> T where T: Clone;
- +}
- +
- +struct InputVec<T> (Vec<T>);
- +struct ParamVec<T> (Vec<T>);
- +struct UninhabitableConstant<T> (Void, PhantomData<T>);
- +
- +trait Eval<T> {
- + type InputIndex;
- + type Input: Input<Self::InputIndex, T>;
- + type ParameterIndex;
- + type Parameter: Parameter<Self::ParameterIndex, T>;
- + type ConstantIndex;
- + type Constant: Constant<Self::ConstantIndex, T>;
- +}
- diff --git a/rust/adagio/src/lib.rs b/rust/adagio/src/lib.rs
- --- a/rust/adagio/src/lib.rs
- +++ b/rust/adagio/src/lib.rs
- @@ -1,9 +1,11 @@
- extern crate num;
- +extern crate void;
- use std::cmp::Ordering;
- use std::collections::HashMap;
- use std::hash::Hash;
- +mod eval;
- mod node_conversions;
- mod node_defs;
- mod node_ordering;
- @@ -68,10 +70,7 @@
- if self.minus {
- env.negate_products(self.terms.iter().cloned())
- } else {
- - self.terms
- - .iter()
- - .map(|p| p.insert(env))
- - .collect()
- + self.terms.iter().map(|p| p.insert(env)).collect()
- }
- }
- }
- @@ -141,9 +140,9 @@
- }
- }
- - /// Makes a canonicalized SystemVariable from the given components.
- - pub fn make_system_variable(&mut self, index: T::SystemVariable) -> Container<Primitive<T>> {
- - Primitive::SystemVariable(index).insert(self)
- + /// Makes a canonicalized Real from the given components.
- + pub fn make_real(&mut self, index: T::Real) -> Container<Primitive<T>> {
- + Primitive::Real(index).insert(self)
- }
- /// Makes a canonicalized Parameter from the given components.
- @@ -355,15 +354,13 @@
- -> Container<Sum<T>>
- where Self: Sized + Clone + Eq + Hash
- {
- - if !self.get_derivative_cache(env).contains_key(&(self.clone(), variable.clone())) {
- + if !self.get_derivative_cache(env)
- + .contains_key(&(self.clone(), variable.clone())) {
- let key = (self.fancy_clone(env), variable.clone());
- let result = self.derivative_impl(variable, env);
- self.get_derivative_cache(env).insert(key, result);
- }
- - match self.get_derivative_cache(env).get(&(self.clone(), variable.clone())) {
- - Some(derivative) => derivative.clone(),
- - None => unreachable!(),
- - }
- + self.get_derivative_cache(env)[&(self.clone(), variable.clone())].clone()
- }
- fn get_storage<'a>(&self, env: &'a mut Environment<T>) -> &'a mut SelfMap<Self>
- @@ -387,10 +384,7 @@
- let value = key.clone();
- self.get_storage(env).insert(key, container(value));
- }
- - match self.get_storage(env).get(self) {
- - Some(node) => node.clone(),
- - None => unreachable!(),
- - }
- + self.get_storage(env)[self].clone()
- }
- }
- diff --git a/rust/adagio/src/node_defs.rs b/rust/adagio/src/node_defs.rs
- --- a/rust/adagio/src/node_defs.rs
- +++ b/rust/adagio/src/node_defs.rs
- @@ -18,7 +18,7 @@
- type Coefficient: Hash + Ord + Clone + Signed + Into<Self::Constant> + From<i32>;
- type Exponent: Hash + Clone + Integer + Into<Self::Constant> + From<i32>;
- type Input: Hash + Ord + Clone;
- - type SystemVariable: Hash + Ord + Clone;
- + type Real: Hash + Ord + Clone;
- type Parameter: Hash + Ord + Clone;
- }
- @@ -27,14 +27,14 @@
- }
- pub struct Sum<T: NodeData> {
- - pre_hash: u64, //I would like these fields to be private if possible.
- + pre_hash: u64,
- pub minus: bool,
- pub constant: T::Constant,
- pub terms: Vec<Container<Product<T>>>,
- }
- pub struct Product<T: NodeData> {
- - pre_hash: u64, //I would like these fields to be private if possible.
- + pre_hash: u64,
- pub coefficient: T::Coefficient,
- pub powers: Vec<Container<Power<T>>>,
- }
- @@ -46,7 +46,7 @@
- pub enum Primitive<T: NodeData> {
- Input(T::Input),
- - SystemVariable(T::SystemVariable),
- + Real(T::Real),
- Parameter(T::Parameter),
- Sigmoid(bool, Container<Sum<T>>),
- }
- @@ -94,7 +94,7 @@
- fn eq(&self, other: &Primitive<T>) -> bool {
- match (self, other) {
- (&Primitive::Input(ref a), &Primitive::Input(ref b)) => a == b,
- - (&Primitive::SystemVariable(ref a), &Primitive::SystemVariable(ref b)) => a == b,
- + (&Primitive::Real(ref a), &Primitive::Real(ref b)) => a == b,
- (&Primitive::Parameter(ref a), &Primitive::Parameter(ref b)) => a == b,
- (&Primitive::Sigmoid(minus_a, ref a), &Primitive::Sigmoid(minus_b, ref b)) => {
- minus_a == minus_b && a == b
- @@ -135,8 +135,8 @@
- "Input".hash(state);
- a.hash(state)
- }
- - &Primitive::SystemVariable(ref a) => {
- - "SystemVariable".hash(state);
- + &Primitive::Real(ref a) => {
- + "Real".hash(state);
- a.hash(state)
- }
- &Primitive::Parameter(ref a) => {
- @@ -158,10 +158,7 @@
- pre_hash: self.pre_hash.clone(),
- minus: self.minus.clone(),
- constant: self.constant.clone(),
- - terms: self.terms
- - .iter()
- - .map(|p| p.insert(env))
- - .collect(),
- + terms: self.terms.iter().map(|p| p.insert(env)).collect(),
- }
- }
- }
- @@ -171,10 +168,7 @@
- Product {
- pre_hash: self.pre_hash.clone(),
- coefficient: self.coefficient.clone(),
- - powers: self.powers
- - .iter()
- - .map(|p| p.insert(env))
- - .collect(),
- + powers: self.powers.iter().map(|p| p.insert(env)).collect(),
- }
- }
- }
- @@ -231,7 +225,7 @@
- fn clone(&self) -> Primitive<T> {
- match self {
- &Primitive::Input(ref a) => Primitive::Input(a.clone()),
- - &Primitive::SystemVariable(ref a) => Primitive::SystemVariable(a.clone()),
- + &Primitive::Real(ref a) => Primitive::Real(a.clone()),
- &Primitive::Parameter(ref a) => Primitive::Parameter(a.clone()),
- &Primitive::Sigmoid(minus, ref a) => Primitive::Sigmoid(minus, a.clone()),
- }
- diff --git a/rust/adagio/src/node_ordering.rs b/rust/adagio/src/node_ordering.rs
- --- a/rust/adagio/src/node_ordering.rs
- +++ b/rust/adagio/src/node_ordering.rs
- @@ -5,7 +5,7 @@
- #[derive(Eq, Ord, PartialEq, PartialOrd)]
- enum PrimitiveType {
- Input,
- - SystemVariable,
- + Real,
- Parameter,
- Sigmoid,
- }
- @@ -14,7 +14,7 @@
- fn simplify(&self) -> PrimitiveType {
- match self {
- &Primitive::Input(_) => PrimitiveType::Input,
- - &Primitive::SystemVariable(_) => PrimitiveType::SystemVariable,
- + &Primitive::Real(_) => PrimitiveType::Real,
- &Primitive::Parameter(_) => PrimitiveType::Parameter,
- &Primitive::Sigmoid(_, _) => PrimitiveType::Sigmoid,
- }
- @@ -27,10 +27,7 @@
- match cmp_result {
- Ordering::Equal => {
- for (self_term, other_term) in
- - self.terms
- - .clone()
- - .into_iter()
- - .zip(other.terms.clone()) {
- + self.terms.clone().into_iter().zip(other.terms.clone()) {
- let cmp_result = self_term.cmp(&other_term);
- match cmp_result {
- Ordering::Equal => (),
- @@ -94,7 +91,7 @@
- fn cmp(&self, other: &Primitive<T>) -> Ordering {
- match (self, other) {
- (&Primitive::Input(ref a), &Primitive::Input(ref b)) => a.cmp(b),
- - (&Primitive::SystemVariable(ref a), &Primitive::SystemVariable(ref b)) => a.cmp(b),
- + (&Primitive::Real(ref a), &Primitive::Real(ref b)) => a.cmp(b),
- (&Primitive::Parameter(ref a), &Primitive::Parameter(ref b)) => a.cmp(b),
- (&Primitive::Sigmoid(minus_a, ref a), &Primitive::Sigmoid(minus_b, ref b)) => {
- let cmp_result = a.cmp(b);
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement