Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- diff --git a/rust/adagio/src/eval.rs b/rust/adagio/src/eval.rs
- new file mode 100644
- diff --git a/rust/adagio/src/lib.rs b/rust/adagio/src/lib.rs
- --- a/rust/adagio/src/lib.rs
- +++ b/rust/adagio/src/lib.rs
- @@ -2,30 +2,37 @@
- use std::cmp::Ordering;
- use std::collections::HashMap;
- -use std::hash::{Hash, Hasher};
- +use std::hash::Hash;
- -use num::{One, Zero};
- -
- +mod node_conversions;
- mod node_defs;
- mod node_ordering;
- -use node_defs::{Container, container, NodeData, Sum, Product, Power, Primitive};
- +use node_conversions::Summable;
- +use node_defs::{Container, container, FancyClone, NodeData, Sum, Product, Power, Primitive};
- +/// Negates a Sum node.
- impl<T: NodeData> Sum<T> {
- fn negate(&self, env: &mut Environment<T>) -> Container<Sum<T>> {
- - env.make_sum(!self.minus, &self.constant, &self.terms)
- + env.make_sum(!self.minus,
- + self.constant.clone(),
- + self.terms.iter().cloned())
- }
- }
- +/// Negates a Product node.
- impl<T: NodeData> Product<T> {
- fn negate(&self, env: &mut Environment<T>) -> Container<Product<T>> {
- - env.make_product(&-self.coefficient.clone(), &self.powers).unwrap()
- + env.secret_make_product(-self.coefficient.clone(), self.powers.clone())
- }
- }
- +/// Convenience type, maps an instance of a type to an Arc containing a clone of the instance.
- type SelfMap<T> = HashMap<T, Container<T>>;
- -struct Environment<T: NodeData> {
- +/// A struct that contains canonicalization maps for every node type, as well as maps to
- +/// memoize partial derivatives.
- +pub struct Environment<T: NodeData> {
- sum: SelfMap<Sum<T>>,
- product: SelfMap<Product<T>>,
- power: SelfMap<Power<T>>,
- @@ -37,6 +44,7 @@
- }
- impl<T: NodeData> Sum<T> {
- + /// Returns the value of the sum, if every term were 0.
- fn at_zero(&self) -> T::Constant {
- if self.minus {
- -self.constant.clone()
- @@ -45,108 +53,143 @@
- }
- }
- + /// Returns a Sum that has the same shape as the sum, but equal to constant at zero.
- fn adjust_constant(&self,
- - constant: &T::Constant,
- + constant: T::Constant,
- env: &mut Environment<T>)
- -> Container<Sum<T>> {
- env.make_sum(self.minus,
- - &(if self.minus {
- - -constant.clone()
- - } else {
- - constant.clone()
- - }),
- - &self.terms)
- + (if self.minus { -constant } else { constant }),
- + self.terms.iter().cloned())
- }
- + /// Returns the terms of the sum, negated if the sum was negated.
- fn conditional_negate(&self, env: &mut Environment<T>) -> Vec<Container<Product<T>>> {
- if self.minus {
- - env.negate_products(&self.terms)
- + env.negate_products(self.terms.iter().cloned())
- } else {
- self.terms
- .iter()
- .map(|p| p.insert(env))
- - .collect::<Vec<_>>()
- + .collect()
- }
- }
- }
- impl<T: NodeData> Environment<T> {
- - fn make_sum(&mut self,
- - minus: bool,
- - constant: &T::Constant,
- - terms: &Vec<Container<Product<T>>>)
- - -> Container<Sum<T>> {
- + /// Makes a canonicalized Sum from the given components.
- + pub fn make_sum<C, N: Into<T::Constant>>(&mut self,
- + minus: bool,
- + constant: N,
- + terms: C)
- + -> Container<Sum<T>>
- + where C: IntoIterator<Item = Container<Product<T>>>
- + {
- + let terms = terms.into_iter().collect::<Vec<_>>();
- + let mut minus = minus;
- + let mut constant = constant.into();
- if minus && terms.len() == 0 {
- - self.make_sum(false, &-constant.clone(), terms)
- - } else {
- - Sum::new(minus, constant.clone(), terms.clone()).insert(self)
- + minus = false;
- + constant = -constant;
- }
- + Sum::new(minus, constant, terms).insert(self)
- + }
- +
- + fn secret_make_product<N: Into<T::Coefficient>>(&mut self,
- + coefficient: N,
- + powers: Vec<Container<Power<T>>>)
- + -> Container<Product<T>> {
- + Product::new(coefficient.into(), powers).insert(self)
- }
- - fn make_product(&mut self,
- - coefficient: &T::Constant,
- - powers: &Vec<Container<Power<T>>>)
- - -> Option<Container<Product<T>>> {
- - if coefficient.clone() == T::Constant::zero() || powers.len() == 0 {
- + /// Makes a canonicalized Product from the given components. Because some inputs create
- + /// invalid products, this returns an Option.
- + pub fn make_product<C, N: Into<T::Coefficient>>(&mut self,
- + coefficient: N,
- + powers: C)
- + -> Option<Container<Product<T>>>
- + where C: IntoIterator<Item = Container<Power<T>>>
- + {
- + let coefficient = coefficient.into();
- + if coefficient != 0.into() {
- + let powers = powers.into_iter().collect::<Vec<_>>();
- + if powers.len() != 0 {
- + return Some(self.secret_make_product(coefficient, powers));
- + }
- + }
- + None
- + }
- +
- + fn secret_make_power<N: Into<T::Exponent>>(&mut self,
- + primitive: Container<Primitive<T>>,
- + exponent: N)
- + -> Container<Power<T>> {
- + Power::new(exponent.into(), primitive).insert(self)
- + }
- +
- + /// Makes a canonicalized Power from the given components. Because some inputs create
- + /// invalid powers, this returns an Option.
- + pub fn make_power<N: Into<T::Exponent>>(&mut self,
- + primitive: Container<Primitive<T>>,
- + exponent: N)
- + -> Option<Container<Power<T>>> {
- + let exponent = exponent.into();
- + if exponent <= 0.into() {
- None
- } else {
- - Some(Product::new(coefficient.clone(), powers.clone()).insert(self))
- - }
- - }
- -
- - fn make_power(&mut self,
- - primitive: &Container<Primitive<T>>,
- - exponent: &T::Exponent)
- - -> Option<Container<Power<T>>> {
- - if exponent.clone() <= T::Exponent::zero() {
- - None
- - } else {
- - Some(Power::new(exponent.clone(), primitive.clone()).insert(self))
- + Some(self.secret_make_power(primitive, exponent))
- }
- }
- - fn make_system_variable(&mut self, index: &T::SystemVariable) -> Container<Primitive<T>> {
- - Primitive::SystemVariable(index.clone()).insert(self)
- + /// Makes a canonicalized SystemVariable from the given components.
- + pub fn make_system_variable(&mut self, index: T::SystemVariable) -> Container<Primitive<T>> {
- + Primitive::SystemVariable(index).insert(self)
- }
- - fn make_parameter(&mut self, name: &T::Parameter) -> Container<Primitive<T>> {
- - Primitive::Parameter(name.clone()).insert(self)
- + /// Makes a canonicalized Parameter from the given components.
- + pub fn make_parameter(&mut self, name: T::Parameter) -> Container<Primitive<T>> {
- + Primitive::Parameter(name).insert(self)
- }
- - fn make_input(&mut self, index: &T::Input) -> Container<Primitive<T>> {
- - Primitive::Input(index.clone()).insert(self)
- + /// Makes a canonicalized Input from the given components.
- + pub fn make_input(&mut self, index: T::Input) -> Container<Primitive<T>> {
- + Primitive::Input(index).insert(self)
- }
- - fn make_sigmoid(&mut self, minus: bool, sum: &Container<Sum<T>>) -> Container<Primitive<T>> {
- - if sum.minus || (sum.terms.len() == 0 && sum.constant < T::Constant::zero()) {
- + /// Makes a canonicalized Sigmoid from the given components.
- + pub fn make_sigmoid(&mut self, minus: bool, sum: Container<Sum<T>>) -> Container<Primitive<T>> {
- + if sum.minus || (sum.terms.len() == 0 && sum.constant < 0.into()) {
- let sum = sum.negate(self);
- - self.make_sigmoid(!minus, &sum)
- + self.make_sigmoid(!minus, sum)
- } else {
- - Primitive::Sigmoid(minus, sum.clone()).insert(self)
- + Primitive::Sigmoid(minus, sum).insert(self)
- }
- }
- - fn negate_products(&mut self,
- - terms: &Vec<Container<Product<T>>>)
- - -> Vec<Container<Product<T>>> {
- - terms.iter().map(|p| p.negate(self)).collect::<Vec<_>>()
- - }
- - fn number(&mut self, number: &T::Constant) -> Container<Sum<T>> {
- - self.make_sum(false, number, &vec![])
- + fn negate_products<C>(&mut self, terms: C) -> Vec<Container<Product<T>>>
- + where C: IntoIterator<Item = Container<Product<T>>>
- + {
- + terms.into_iter().map(|p| p.negate(self)).collect()
- }
- - fn add<A: Node<T>, B: Node<T>>(&mut self,
- - left: &Container<A>,
- - right: &Container<B>)
- - -> Container<Sum<T>> {
- + /// Makes a constant Sum.
- + pub fn number<N: Into<T::Constant>>(&mut self, number: N) -> Container<Sum<T>> {
- + self.make_sum(false, number, vec![])
- + }
- +
- + /// Adds two Nodes together, producing a Sum.
- + pub fn add<A: Node<T>, B: Node<T>>(&mut self,
- + left: Container<A>,
- + right: Container<B>)
- + -> Container<Sum<T>> {
- let left_sum = left.as_sum(self);
- let right_sum = right.as_sum(self);
- let constant = left_sum.at_zero() + right_sum.at_zero();
- if left_sum.terms.len() == 0 {
- - return right_sum.adjust_constant(&constant, self);
- + return right_sum.adjust_constant(constant, self);
- }
- if right_sum.terms.len() == 0 {
- - return left_sum.adjust_constant(&constant, self);
- + return left_sum.adjust_constant(constant, self);
- }
- // Begin port of weird logic
- @@ -170,8 +213,9 @@
- Ordering::Equal => {
- let coefficient = left_term.coefficient.clone() +
- right_term.coefficient.clone();
- - if coefficient != T::Constant::zero() {
- - terms.push(self.make_product(&coefficient, &left_term.powers).unwrap());
- + match self.make_product(coefficient, left_term.powers.iter().cloned()) {
- + Some(product) => terms.push(product),
- + None => (),
- }
- left_index += 1;
- right_index += 1;
- @@ -182,40 +226,41 @@
- }
- }
- if left_index == left_length {
- - terms.extend_from_slice(&right_terms[..].split_at(right_index).1);
- + terms.extend_from_slice(&right_terms[right_index..]);
- }
- if right_index == right_length {
- - terms.extend_from_slice(&left_terms[..].split_at(left_index).1);
- + terms.extend_from_slice(&left_terms[left_index..]);
- }
- }
- terms.shrink_to_fit(); // Why not.
- - if terms.len() == 0 || terms[0].coefficient >= T::Constant::zero() {
- - self.make_sum(false, &constant, &terms)
- + if terms.len() == 0 || terms[0].coefficient >= 0.into() {
- + self.make_sum(false, constant, terms)
- } else {
- - let terms = self.negate_products(&terms);
- - self.make_sum(true, &constant, &terms)
- + let terms = self.negate_products(terms);
- + self.make_sum(true, constant, terms)
- }
- }
- - fn multiply<A: Node<T>, B: Node<T>>(&mut self,
- - left: &Container<A>,
- - right: &Container<B>)
- - -> Container<Sum<T>> {
- + /// Multiplies two Nodes together, producing a Sum.
- + pub fn multiply<A: Node<T>, B: Node<T>>(&mut self,
- + left: Container<A>,
- + right: Container<B>)
- + -> Container<Sum<T>> {
- let left_sum = left.as_sum(self);
- let right_sum = right.as_sum(self);
- let minus = left_sum.minus != right_sum.minus;
- let first_constant = left_sum.constant.clone() * right_sum.constant.clone();
- - let first = self.make_sum(minus, &first_constant, &vec![]);
- + let first = self.make_sum(minus, first_constant, vec![]);
- let mut outer_terms = Vec::with_capacity(right_sum.terms.len());
- - if left_sum.constant != T::Constant::zero() {
- + if left_sum.constant != 0.into() {
- for term in &right_sum.terms {
- - let outer_coefficient = left_sum.constant.clone() * term.coefficient.clone();
- - let outer_term = self.make_product(&outer_coefficient, &term.powers).unwrap();
- - if left_sum.constant < T::Constant::zero() {
- + let outer_coefficient = left_sum.constant.clone().into() * term.coefficient.clone();
- + let outer_term = self.secret_make_product(outer_coefficient, term.powers.clone());
- + if left_sum.constant < 0.into() {
- outer_terms.push(outer_term.negate(self));
- } else {
- outer_terms.push(outer_term);
- @@ -223,17 +268,16 @@
- }
- }
- outer_terms.shrink_to_fit();
- - let outer = self.make_sum((left_sum.constant < T::Constant::zero()) != minus,
- - &T::Constant::zero(),
- - &outer_terms);
- + let outer = self.make_sum((left_sum.constant < 0.into()) != minus, 0, outer_terms);
- // Should be possible to turn these into functions.
- let mut inner_terms = Vec::with_capacity(left_sum.terms.len());
- - if right_sum.constant != T::Constant::zero() {
- + if right_sum.constant != 0.into() {
- for term in &left_sum.terms {
- - let inner_coefficient = right_sum.constant.clone() * term.coefficient.clone();
- - let inner_term = self.make_product(&inner_coefficient, &term.powers).unwrap();
- - if right_sum.constant < T::Constant::zero() {
- + let inner_coefficient = right_sum.constant.clone().into() *
- + term.coefficient.clone();
- + let inner_term = self.secret_make_product(inner_coefficient, term.powers.clone());
- + if right_sum.constant < 0.into() {
- inner_terms.push(inner_term.negate(self));
- } else {
- inner_terms.push(inner_term);
- @@ -241,24 +285,22 @@
- }
- }
- inner_terms.shrink_to_fit();
- - let inner = self.make_sum((right_sum.constant < T::Constant::zero()) != minus,
- - &T::Constant::zero(),
- - &inner_terms);
- + let inner = self.make_sum((right_sum.constant < 0.into()) != minus, 0, inner_terms);
- - let mut last = self.number(&T::Constant::zero());
- + let mut last = self.number(0);
- for left_term in &left_sum.terms {
- for right_term in &right_sum.terms {
- - let combined_terms = self._combine_terms(&left_term.powers, &right_term.powers);
- let last_coefficient = left_term.coefficient.clone() *
- right_term.coefficient.clone();
- - let product = self.make_product(&last_coefficient, &combined_terms).unwrap();
- - last = self.add(&last, &product);
- + let combined_terms = self._combine_terms(&left_term.powers, &right_term.powers);
- + let product = self.secret_make_product(last_coefficient, combined_terms);
- + last = self.add(last, product);
- }
- }
- - last = self.add(&last, &inner);
- - last = self.add(&last, &outer);
- - self.add(&last, &first)
- + last = self.add(last, inner);
- + last = self.add(last, outer);
- + self.add(last, first)
- }
- fn _combine_terms(&mut self,
- @@ -275,14 +317,14 @@
- while left_index < left_length && right_index < right_length {
- let ref left_power = left[left_index];
- let ref right_power = right[right_length];
- - match left_power.fuzzy_cmp(&right_power) {
- + match left_power.fuzzy_cmp(right_power.as_ref()) {
- Ordering::Less => {
- powers.push(left_power.clone());
- left_index += 1;
- }
- Ordering::Equal => {
- let exponent = left_power.exponent.clone() + right_power.exponent.clone();
- - powers.push(self.make_power(&left_power.primitive, &exponent).unwrap());
- + powers.push(self.secret_make_power(left_power.primitive.clone(), exponent));
- left_index += 1;
- right_index += 1;
- }
- @@ -292,10 +334,10 @@
- }
- }
- if left_index == left_length {
- - powers.extend_from_slice(&right[..].split_at(right_index).1);
- + powers.extend_from_slice(&right[right_index..]);
- }
- if right_index == right_length {
- - powers.extend_from_slice(&left[..].split_at(left_index).1);
- + powers.extend_from_slice(&left[left_index..]);
- }
- }
- @@ -306,55 +348,37 @@
- }
- }
- -trait Summable<T: NodeData> {
- - fn as_sum(&self, env: &mut Environment<T>) -> Container<Sum<T>>;
- -}
- -
- -trait Productable<T: NodeData> {
- - fn as_product(&self, env: &mut Environment<T>) -> Container<Product<T>>;
- -}
- -
- -fn _as_sum<T: NodeData>(productable: &Productable<T>,
- - env: &mut Environment<T>)
- - -> Container<Sum<T>> {
- - let product = productable.as_product(env);
- - if product.coefficient < T::Constant::zero() {
- - let product = product.negate(env);
- - env.make_sum(true, &T::Constant::zero(), &vec![product])
- - } else {
- - env.make_sum(false, &T::Constant::zero(), &vec![product])
- - }
- -}
- -
- -impl<T: NodeData> Summable<T> for Product<T> {
- - fn as_sum(&self, env: &mut Environment<T>) -> Container<Sum<T>> {
- - _as_sum(self, env)
- - }
- -}
- -
- -impl<T: NodeData> Summable<T> for Power<T> {
- - fn as_sum(&self, env: &mut Environment<T>) -> Container<Sum<T>> {
- - _as_sum(self, env)
- - }
- -}
- -
- -impl<T: NodeData> Summable<T> for Primitive<T> {
- - fn as_sum(&self, env: &mut Environment<T>) -> Container<Sum<T>> {
- - _as_sum(self, env)
- - }
- -}
- -
- -trait Node<T: NodeData>: Summable<T> {
- +pub trait Node<T: NodeData>: Summable<T> + FancyClone<T> {
- fn partial_derivative(&self,
- variable: &T::Parameter,
- env: &mut Environment<T>)
- - -> Container<Sum<T>>;
- -
- - fn fancy_clone(&self, env: &mut Environment<T>) -> Self;
- + -> Container<Sum<T>>
- + where Self: Sized + Clone + Eq + Hash
- + {
- + if !self.get_derivative_cache(env).contains_key(&(self.clone(), variable.clone())) {
- + let key = (self.fancy_clone(env), variable.clone());
- + let result = self.derivative_impl(variable, env);
- + self.get_derivative_cache(env).insert(key, result);
- + }
- + match self.get_derivative_cache(env).get(&(self.clone(), variable.clone())) {
- + Some(derivative) => derivative.clone(),
- + None => unreachable!(),
- + }
- + }
- fn get_storage<'a>(&self, env: &'a mut Environment<T>) -> &'a mut SelfMap<Self>
- where Self: Sized;
- + fn get_derivative_cache<'a>(&self,
- + env: &'a mut Environment<T>)
- + -> &'a mut HashMap<(Self, T::Parameter), Container<Sum<T>>>
- + where Self: Sized;
- +
- + fn derivative_impl(&self,
- + variable: &T::Parameter,
- + env: &mut Environment<T>)
- + -> Container<Sum<T>>;
- +
- fn insert(&self, env: &mut Environment<T>) -> Container<Self>
- where Self: Sized + Clone + Eq + Hash
- {
- @@ -364,173 +388,142 @@
- self.get_storage(env).insert(key, container(value));
- }
- match self.get_storage(env).get(self) {
- - Some(node) => node,
- - None => unreachable!(),
- - }
- - .clone()
- - }
- -}
- -
- -impl<T: NodeData> Summable<T> for Sum<T> {
- - fn as_sum(&self, env: &mut Environment<T>) -> Container<Sum<T>> {
- - self.insert(env)
- + Some(node) => node.clone(),
- + None => unreachable!(),
- + }
- }
- }
- impl<T: NodeData> Node<T> for Sum<T> {
- - fn partial_derivative(&self,
- - variable: &T::Parameter,
- - env: &mut Environment<T>)
- - -> Container<Sum<T>> {
- - if !env.sum_derivative.contains_key(&(self.clone(), variable.clone())) {
- - let key = (self.fancy_clone(env), variable.clone());
- - let mut result = env.number(&T::Constant::zero());
- - for term in &self.terms {
- - let partial = term.partial_derivative(variable, env);
- - result = env.add(&result, &partial);
- - }
- - if self.minus {
- - result = result.negate(env);
- - }
- - env.sum_derivative.insert(key, result);
- + fn derivative_impl(&self,
- + variable: &T::Parameter,
- + env: &mut Environment<T>)
- + -> Container<Sum<T>> {
- + let mut result = env.number(0);
- + for term in &self.terms {
- + let partial = term.partial_derivative(variable, env);
- + result = env.add(result, partial);
- }
- - match env.sum_derivative.get(&(self.clone(), variable.clone())) {
- - Some(derivative) => derivative.clone(),
- - None => unreachable!(),
- + if self.minus {
- + result = result.negate(env);
- }
- - }
- -
- - fn fancy_clone(&self, env: &mut Environment<T>) -> Sum<T> {
- - Sum {
- - pre_hash: self.pre_hash.clone(),
- - minus: self.minus.clone(),
- - constant: self.constant.clone(),
- - terms: self.terms
- - .iter()
- - .map(|p| p.insert(env))
- - .collect::<Vec<_>>(),
- - }
- + result
- }
- fn get_storage<'a>(&self, env: &'a mut Environment<T>) -> &'a mut SelfMap<Sum<T>> {
- &mut env.sum
- }
- -}
- -impl<T: NodeData> Productable<T> for Product<T> {
- - fn as_product(&self, env: &mut Environment<T>) -> Container<Product<T>> {
- - self.insert(env)
- + fn get_derivative_cache<'a>(&self,
- + env: &'a mut Environment<T>)
- + -> &'a mut HashMap<(Sum<T>, T::Parameter), Container<Sum<T>>>
- + where Self: Sized
- + {
- + &mut env.sum_derivative
- }
- }
- impl<T: NodeData> Node<T> for Product<T> {
- - fn partial_derivative(&self,
- - variable: &T::Parameter,
- - env: &mut Environment<T>)
- - -> Container<Sum<T>> {
- - if !env.product_derivative.contains_key(&(self.clone(), variable.clone())) {
- - let key = (self.fancy_clone(env), variable.clone());
- - let result = if self.powers.len() == 1 {
- - self.powers[0].partial_derivative(variable, env)
- - } else {
- - let mut result = env.number(&T::Constant::zero());
- - unimplemented!()
- - /*for (index, focus) in self.powers.iter().enumerate() {
- - let remainder = env.make_product(T::Constant::one(), )
- - }*/
- - };
- - env.product_derivative.insert(key, result);
- - }
- - match env.product_derivative.get(&(self.clone(), variable.clone())) {
- - Some(derivative) => derivative.clone(),
- - None => unreachable!(),
- - }
- - }
- -
- - fn fancy_clone(&self, env: &mut Environment<T>) -> Product<T> {
- - Product {
- - pre_hash: self.pre_hash.clone(),
- - coefficient: self.coefficient.clone(),
- - powers: self.powers
- - .iter()
- - .map(|p| p.insert(env))
- - .collect::<Vec<_>>(),
- - }
- + fn derivative_impl(&self,
- + variable: &T::Parameter,
- + env: &mut Environment<T>)
- + -> Container<Sum<T>> {
- + let coefficient = env.number(self.coefficient.clone());
- + let result = if self.powers.len() == 1 {
- + self.powers[0].partial_derivative(variable, env)
- + } else {
- + let mut result = env.number(0);
- + for (index, focus) in self.powers.iter().enumerate() {
- + let remainder = env.secret_make_product(1,
- + self.powers[..index]
- + .iter()
- + .chain(self.powers[index + 1..]
- + .iter())
- + .cloned()
- + .collect());
- + let partial_derivative = focus.partial_derivative(variable, env);
- + let product = env.multiply(remainder, partial_derivative);
- + result = env.add(result, product);
- + }
- + result
- + };
- + env.multiply(result, coefficient)
- }
- fn get_storage<'a>(&self, env: &'a mut Environment<T>) -> &'a mut SelfMap<Product<T>> {
- &mut env.product
- }
- -}
- -impl<T: NodeData> Productable<T> for Power<T> {
- - fn as_product(&self, env: &mut Environment<T>) -> Container<Product<T>> {
- - let power = self.insert(env);
- - env.make_product(&T::Constant::one(), &vec![power]).unwrap()
- + fn get_derivative_cache<'a>(&self,
- + env: &'a mut Environment<T>)
- + -> &'a mut HashMap<(Product<T>, T::Parameter), Container<Sum<T>>>
- + where Self: Sized
- + {
- + &mut env.product_derivative
- }
- }
- impl<T: NodeData> Node<T> for Power<T> {
- - fn partial_derivative(&self,
- - variable: &T::Parameter,
- - env: &mut Environment<T>)
- - -> Container<Sum<T>> {
- - unimplemented!();
- - }
- -
- - fn fancy_clone(&self, env: &mut Environment<T>) -> Power<T> {
- - Power {
- - exponent: self.exponent.clone(),
- - primitive: self.primitive.insert(env),
- + fn derivative_impl(&self,
- + variable: &T::Parameter,
- + env: &mut Environment<T>)
- + -> Container<Sum<T>> {
- + let derivative = self.primitive.partial_derivative(variable, env);
- + if self.exponent == 1.into() {
- + derivative
- + } else {
- + let coefficient = env.number(self.exponent.clone());
- + let remainder = env.secret_make_power(self.primitive.clone(),
- + self.exponent.clone() - 1.into());
- + let product = env.multiply(derivative, remainder);
- + env.multiply(coefficient, product)
- }
- }
- fn get_storage<'a>(&self, env: &'a mut Environment<T>) -> &'a mut SelfMap<Power<T>> {
- &mut env.power
- }
- -}
- -impl<T: NodeData> Primitive<T> {
- - fn as_power(&self, env: &mut Environment<T>) -> Container<Power<T>> {
- - let primitive = self.insert(env);
- - env.make_power(&primitive, &T::Exponent::one()).unwrap()
- - }
- -}
- -
- -impl<T: NodeData> Productable<T> for Primitive<T> {
- - fn as_product(&self, env: &mut Environment<T>) -> Container<Product<T>> {
- - match self {
- - &Primitive::Sigmoid(true, ref sum) => {
- - let sigmoid = env.make_sigmoid(false, sum);
- - let power = sigmoid.as_power(env);
- - env.make_product(&-T::Constant::one(), &vec![power]).unwrap()
- - }
- - _ => {
- - let power = self.as_power(env);
- - power.as_product(env)
- - }
- - }
- + fn get_derivative_cache<'a>(&self,
- + env: &'a mut Environment<T>)
- + -> &'a mut HashMap<(Power<T>, T::Parameter), Container<Sum<T>>>
- + where Self: Sized
- + {
- + &mut env.power_derivative
- }
- }
- impl<T: NodeData> Node<T> for Primitive<T> {
- - fn partial_derivative(&self,
- - variable: &T::Parameter,
- - env: &mut Environment<T>)
- - -> Container<Sum<T>> {
- - unimplemented!();
- - }
- -
- - fn fancy_clone(&self, env: &mut Environment<T>) -> Primitive<T> {
- + fn derivative_impl(&self,
- + variable: &T::Parameter,
- + env: &mut Environment<T>)
- + -> Container<Sum<T>> {
- match self {
- - &Primitive::Sigmoid(minus, ref a) => Primitive::Sigmoid(minus, a.insert(env)),
- - _ => self.clone(),
- + &Primitive::Sigmoid(minus, ref sum) => {
- + let derivative = sum.partial_derivative(variable, env);
- + let self_in_env = self.insert(env);
- + let self_squared = env.secret_make_power(self_in_env, 2);
- + let minus_self_squared = env.secret_make_product(-1, vec![self_squared]);
- + let sum = env.make_sum(minus, 1, vec![minus_self_squared]);
- + env.multiply(derivative, sum)
- + }
- + &Primitive::Parameter(ref p) => env.number(if p == variable { 1 } else { 0 }),
- + _ => env.number(0),
- }
- }
- fn get_storage<'a>(&self, env: &'a mut Environment<T>) -> &'a mut SelfMap<Primitive<T>> {
- &mut env.primitive
- }
- +
- + fn get_derivative_cache<'a>
- + (&self,
- + env: &'a mut Environment<T>)
- + -> &'a mut HashMap<(Primitive<T>, T::Parameter), Container<Sum<T>>>
- + where Self: Sized
- + {
- + &mut env.primitive_derivative
- + }
- }
- #[cfg(test)]
- diff --git a/rust/adagio/src/node_conversions.rs b/rust/adagio/src/node_conversions.rs
- new file mode 100644
- --- /dev/null
- +++ b/rust/adagio/src/node_conversions.rs
- @@ -0,0 +1,83 @@
- +use {Environment, Node};
- +
- +use node_defs::{Container, NodeData, Sum, Product, Power, Primitive};
- +
- +pub trait Summable<T: NodeData> {
- + fn as_sum(&self, env: &mut Environment<T>) -> Container<Sum<T>>;
- +}
- +
- +trait Productable<T: NodeData> {
- + fn as_product(&self, env: &mut Environment<T>) -> Container<Product<T>>;
- +}
- +
- +fn _as_sum<T: NodeData, U: Productable<T>>(productable: &U,
- + env: &mut Environment<T>)
- + -> Container<Sum<T>> {
- + let product = productable.as_product(env);
- + if product.coefficient < 0.into() {
- + let product = product.as_ref().negate(env);
- + env.make_sum(true, 0, vec![product])
- + } else {
- + env.make_sum(false, 0, vec![product])
- + }
- +}
- +
- +impl<T: NodeData> Productable<T> for Product<T> {
- + fn as_product(&self, env: &mut Environment<T>) -> Container<Product<T>> {
- + self.insert(env)
- + }
- +}
- +
- +impl<T: NodeData> Productable<T> for Power<T> {
- + fn as_product(&self, env: &mut Environment<T>) -> Container<Product<T>> {
- + let power = self.insert(env);
- + env.secret_make_product(1, vec![power])
- + }
- +}
- +
- +impl<T: NodeData> Primitive<T> {
- + fn as_power(&self, env: &mut Environment<T>) -> Container<Power<T>> {
- + let primitive = self.insert(env);
- + env.secret_make_power(primitive, 1)
- + }
- +}
- +
- +impl<T: NodeData> Productable<T> for Primitive<T> {
- + fn as_product(&self, env: &mut Environment<T>) -> Container<Product<T>> {
- + match self {
- + &Primitive::Sigmoid(true, ref sum) => {
- + let sigmoid = env.make_sigmoid(false, sum.clone());
- + let power = sigmoid.as_power(env);
- + env.secret_make_product(-1, vec![power])
- + }
- + _ => {
- + let power = self.as_power(env);
- + power.as_product(env)
- + }
- + }
- + }
- +}
- +
- +impl<T: NodeData> Summable<T> for Sum<T> {
- + fn as_sum(&self, env: &mut Environment<T>) -> Container<Sum<T>> {
- + self.insert(env)
- + }
- +}
- +
- +impl<T: NodeData> Summable<T> for Product<T> {
- + fn as_sum(&self, env: &mut Environment<T>) -> Container<Sum<T>> {
- + _as_sum(self, env)
- + }
- +}
- +
- +impl<T: NodeData> Summable<T> for Power<T> {
- + fn as_sum(&self, env: &mut Environment<T>) -> Container<Sum<T>> {
- + _as_sum(self, env)
- + }
- +}
- +
- +impl<T: NodeData> Summable<T> for Primitive<T> {
- + fn as_sum(&self, env: &mut Environment<T>) -> Container<Sum<T>> {
- + _as_sum(self, env)
- + }
- +}
- diff --git a/rust/adagio/src/node_defs.rs b/rust/adagio/src/node_defs.rs
- --- a/rust/adagio/src/node_defs.rs
- +++ b/rust/adagio/src/node_defs.rs
- @@ -5,6 +5,8 @@
- use num::{Integer, Signed};
- +use {Environment, Node};
- +
- pub type Container<T> = Arc<T>;
- pub fn container<T>(data: T) -> Container<T> {
- @@ -12,23 +14,28 @@
- }
- pub trait NodeData {
- - type Constant: Hash + Ord + Clone + Signed;
- - type Exponent: Hash + Clone + Integer;
- + type Constant: Hash + Ord + Clone + Signed + Into<Self::Coefficient> + From<i32>;
- + type Coefficient: Hash + Ord + Clone + Signed + Into<Self::Constant> + From<i32>;
- + type Exponent: Hash + Clone + Integer + Into<Self::Constant> + From<i32>;
- type Input: Hash + Ord + Clone;
- type SystemVariable: Hash + Ord + Clone;
- type Parameter: Hash + Ord + Clone;
- }
- +pub trait FancyClone<T: NodeData> {
- + fn fancy_clone(&self, env: &mut Environment<T>) -> Self;
- +}
- +
- pub struct Sum<T: NodeData> {
- - pub pre_hash: u64, //I would like these fields to be private if possible.
- + pre_hash: u64, //I would like these fields to be private if possible.
- pub minus: bool,
- pub constant: T::Constant,
- pub terms: Vec<Container<Product<T>>>,
- }
- pub struct Product<T: NodeData> {
- - pub pre_hash: u64, //I would like these fields to be private if possible.
- - pub coefficient: T::Constant,
- + pre_hash: u64, //I would like these fields to be private if possible.
- + pub coefficient: T::Coefficient,
- pub powers: Vec<Container<Power<T>>>,
- }
- @@ -145,6 +152,51 @@
- }
- }
- +impl<T: NodeData> FancyClone<T> for Sum<T> {
- + fn fancy_clone(&self, env: &mut Environment<T>) -> Sum<T> {
- + Sum {
- + pre_hash: self.pre_hash.clone(),
- + minus: self.minus.clone(),
- + constant: self.constant.clone(),
- + terms: self.terms
- + .iter()
- + .map(|p| p.insert(env))
- + .collect(),
- + }
- + }
- +}
- +
- +impl<T: NodeData> FancyClone<T> for Product<T> {
- + fn fancy_clone(&self, env: &mut Environment<T>) -> Product<T> {
- + Product {
- + pre_hash: self.pre_hash.clone(),
- + coefficient: self.coefficient.clone(),
- + powers: self.powers
- + .iter()
- + .map(|p| p.insert(env))
- + .collect(),
- + }
- + }
- +}
- +
- +impl<T: NodeData> FancyClone<T> for Power<T> {
- + fn fancy_clone(&self, env: &mut Environment<T>) -> Power<T> {
- + Power {
- + exponent: self.exponent.clone(),
- + primitive: self.primitive.insert(env),
- + }
- + }
- +}
- +
- +impl<T: NodeData> FancyClone<T> for Primitive<T> {
- + fn fancy_clone(&self, env: &mut Environment<T>) -> Primitive<T> {
- + match self {
- + &Primitive::Sigmoid(minus, ref a) => Primitive::Sigmoid(minus, a.insert(env)),
- + _ => self.clone(),
- + }
- + }
- +}
- +
- impl<T: NodeData> Clone for Sum<T> {
- fn clone(&self) -> Sum<T> {
- Sum {
- @@ -202,7 +254,7 @@
- }
- impl<T: NodeData> Product<T> {
- - pub fn new(coefficient: T::Constant, powers: Vec<Container<Power<T>>>) -> Product<T> {
- + pub fn new(coefficient: T::Coefficient, powers: Vec<Container<Power<T>>>) -> Product<T> {
- let mut s = DefaultHasher::new();
- coefficient.hash(&mut s);
- powers.hash(&mut s);
- diff --git a/rust/adagio/src/node_ordering.rs b/rust/adagio/src/node_ordering.rs
- --- a/rust/adagio/src/node_ordering.rs
- +++ b/rust/adagio/src/node_ordering.rs
- @@ -1,6 +1,6 @@
- use std::cmp::Ordering;
- -use node_defs::{Container, NodeData, Sum, Product, Power, Primitive};
- +use node_defs::{NodeData, Sum, Product, Power, Primitive};
- #[derive(Eq, Ord, PartialEq, PartialOrd)]
- enum PrimitiveType {
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement