Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
extern crate num;

use std::cmp::Ordering;
use std::collections::hash_map::{DefaultHasher, Entry};
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::sync::Arc;

use num::{Integer, One, Signed, Zero};
- trait NodeData {
- type Constant: Hash + Eq + Ord + Clone + Signed;
- type Coefficient: Hash + Eq + Ord + Clone + Signed;
- type Exponent: Hash + Clone + Integer;
- type Input: Hash + Eq + Clone;
- type SystemVariable: Hash + Eq + Clone;
- type Parameter: Hash + Eq + Clone;
- }
- struct Sum<T: NodeData> {
- pre_hash: u64,
- minus: bool,
- constant: T::Constant,
- terms: Vec<Arc<Product<T>>>,
- }
- impl<T: NodeData> Sum<T> {
- fn negate(&self, env: &mut Environment<T>) -> Arc<Sum<T>> {
- env.make_sum(!self.minus, &self.constant, &self.terms)
- }
- }
- struct Product<T: NodeData> {
- pre_hash: u64,
- coefficient: T::Coefficient,
- powers: Vec<Arc<Power<T>>>,
- }
- impl<T: NodeData> Product<T> {
- fn negate(&self, env: &mut Environment<T>) -> Arc<Product<T>> {
- env.make_product(&-self.coefficient.clone(), &self.powers).unwrap()
- }
- }
- struct Power<T: NodeData> {
- exponent: T::Exponent,
- primitive: Arc<Primitive<T>>,
- }
- enum Primitive<T: NodeData> {
- Input(T::Input),
- SystemVariable(T::SystemVariable),
- Parameter(T::Parameter),
- Sigmoid(bool, Arc<Sum<T>>),
- }
/// Intern table: maps a node value to the single shared `Arc` handed out for it.
type SelfMap<K> = HashMap<K, Arc<K>>;
- struct Environment<T: NodeData> {
- sum: SelfMap<Sum<T>>,
- product: SelfMap<Product<T>>,
- power: SelfMap<Power<T>>,
- primitive: SelfMap<Primitive<T>>, //I could put in the maps for the various partial
- //derivatives, but that's a
- //non-vital optimization.
- }
- impl<T: NodeData> Environment<T> {
- fn make_sum(&mut self,
- minus: bool,
- constant: &T::Constant,
- terms: &Vec<Arc<Product<T>>>)
- -> Arc<Sum<T>> {
- if minus && terms.len() == 0 {
- self.make_sum(false, &-constant.clone(), terms)
- } else {
- Sum::new(minus, constant.clone(), terms.clone()).insert(self)
- }
- }
- fn make_product(&mut self,
- coefficient: &T::Coefficient,
- powers: &Vec<Arc<Power<T>>>)
- -> Option<Arc<Product<T>>> {
- if coefficient.clone() == T::Coefficient::zero() || powers.len() == 0 {
- None
- } else {
- Some(Product::new(coefficient.clone(), powers.clone()).insert(self))
- }
- }
- fn make_power(&mut self,
- primitive: &Arc<Primitive<T>>,
- exponent: &T::Exponent)
- -> Option<Arc<Power<T>>> {
- if exponent.clone() <= T::Exponent::zero() {
- None
- } else {
- Some(Power::new(exponent.clone(), primitive.clone()).insert(self))
- }
- }
- fn make_system_variable(&mut self, index: &T::SystemVariable) -> Arc<Primitive<T>> {
- Primitive::SystemVariable(index.clone()).insert(self)
- }
- fn make_parameter(&mut self, name: &T::Parameter) -> Arc<Primitive<T>> {
- Primitive::Parameter(name.clone()).insert(self)
- }
- fn make_input(&mut self, index: &T::Input) -> Arc<Primitive<T>> {
- Primitive::Input(index.clone()).insert(self)
- }
- fn make_sigmoid(&mut self, minus: bool, sum: &Arc<Sum<T>>) -> Arc<Primitive<T>> {
- if sum.minus || (sum.terms.len() == 0 && sum.constant < T::Constant::zero()) {
- let sum = sum.negate(self);
- self.make_sigmoid(!minus, &sum)
- } else {
- Primitive::Sigmoid(minus, sum.clone()).insert(self)
- }
- }
- //fn number(&mut self, &T::Constant)
- }
- trait Summable<T: NodeData> {
- fn as_sum(&self, env: &mut Environment<T>) -> Arc<Sum<T>>;
- }
- trait Productable<T: NodeData> {
- fn as_product(&self, env: &mut Environment<T>) -> Arc<Product<T>>;
- }
- fn _as_sum<T: NodeData>(productable: &Productable<T>, env: &mut Environment<T>) -> Arc<Sum<T>> {
- let product = productable.as_product(env);
- if product.coefficient < T::Coefficient::zero() {
- let product = product.negate(env);
- env.make_sum(true, &T::Constant::zero(), &vec![product])
- } else {
- env.make_sum(false, &T::Constant::zero(), &vec![product])
- }
- }
- impl<T: NodeData> Summable<T> for Product<T> {
- fn as_sum(&self, env: &mut Environment<T>) -> Arc<Sum<T>> {
- _as_sum(self, env)
- }
- }
- impl<T: NodeData> Summable<T> for Power<T> {
- fn as_sum(&self, env: &mut Environment<T>) -> Arc<Sum<T>> {
- _as_sum(self, env)
- }
- }
- impl<T: NodeData> Summable<T> for Primitive<T> {
- fn as_sum(&self, env: &mut Environment<T>) -> Arc<Sum<T>> {
- _as_sum(self, env)
- }
- }
- trait Node<T: NodeData>: Summable<T> {
- fn partial_derivative(&self, variable: &T::Parameter, env: &mut Environment<T>) -> Arc<Sum<T>>;
- fn fancy_clone(&self, env: &mut Environment<T>) -> Self;
- fn get_storage<'a>(&self, env: &'a mut Environment<T>) -> &'a mut SelfMap<Self>
- where Self: Sized;
- fn insert(&self, env: &mut Environment<T>) -> Arc<Self>
- where Self: Sized + Clone + Eq + Hash
- {
- if !self.get_storage(env).contains_key(self) {
- let key = self.fancy_clone(env);
- let value = key.clone();
- self.get_storage(env).insert(key, Arc::new(value));
- }
- match self.get_storage(env).get(self) {
- Some(node) => node,
- None => unreachable!(),
- }
- .clone()
- }
- }
- impl<T: NodeData> Summable<T> for Sum<T> {
- fn as_sum(&self, env: &mut Environment<T>) -> Arc<Sum<T>> {
- self.insert(env)
- }
- }
- impl<T: NodeData> Node<T> for Sum<T> {
- fn partial_derivative(&self, variable: &T::Parameter, env: &mut Environment<T>) -> Arc<Sum<T>> {
- unimplemented!();
- }
- fn fancy_clone(&self, env: &mut Environment<T>) -> Sum<T> {
- Sum {
- pre_hash: self.pre_hash.clone(),
- minus: self.minus.clone(),
- constant: self.constant.clone(),
- terms: self.terms
- .iter()
- .map(|p| p.insert(env))
- .collect::<Vec<_>>(),
- }
- }
- fn get_storage<'a>(&self, env: &'a mut Environment<T>) -> &'a mut SelfMap<Sum<T>> {
- &mut env.sum
- }
- }
- impl<T: NodeData> Productable<T> for Product<T> {
- fn as_product(&self, env: &mut Environment<T>) -> Arc<Product<T>> {
- self.insert(env)
- }
- }
- impl<T: NodeData> Node<T> for Product<T> {
- fn partial_derivative(&self, variable: &T::Parameter, env: &mut Environment<T>) -> Arc<Sum<T>> {
- unimplemented!();
- }
- fn fancy_clone(&self, env: &mut Environment<T>) -> Product<T> {
- Product {
- pre_hash: self.pre_hash.clone(),
- coefficient: self.coefficient.clone(),
- powers: self.powers
- .iter()
- .map(|p| p.insert(env))
- .collect::<Vec<_>>(),
- }
- }
- fn get_storage<'a>(&self, env: &'a mut Environment<T>) -> &'a mut SelfMap<Product<T>> {
- &mut env.product
- }
- }
- impl<T: NodeData> Productable<T> for Power<T> {
- fn as_product(&self, env: &mut Environment<T>) -> Arc<Product<T>> {
- let power = self.insert(env);
- env.make_product(&T::Coefficient::one(), &vec![power]).unwrap()
- }
- }
- impl<T: NodeData> Node<T> for Power<T> {
- fn partial_derivative(&self, variable: &T::Parameter, env: &mut Environment<T>) -> Arc<Sum<T>> {
- unimplemented!();
- }
- fn fancy_clone(&self, env: &mut Environment<T>) -> Power<T> {
- Power {
- exponent: self.exponent.clone(),
- primitive: self.primitive.insert(env),
- }
- }
- fn get_storage<'a>(&self, env: &'a mut Environment<T>) -> &'a mut SelfMap<Power<T>> {
- &mut env.power
- }
- }
- impl<T: NodeData> Primitive<T> {
- fn as_power(&self, env: &mut Environment<T>) -> Arc<Power<T>> {
- let primitive = self.insert(env);
- env.make_power(&primitive, &T::Exponent::one()).unwrap()
- }
- }
- impl<T: NodeData> Productable<T> for Primitive<T> {
- fn as_product(&self, env: &mut Environment<T>) -> Arc<Product<T>> {
- match self {
- &Primitive::Sigmoid(true, ref sum) => {
- let sigmoid = env.make_sigmoid(false, sum);
- let power = sigmoid.as_power(env);
- env.make_product(&-T::Coefficient::one(), &vec![power]).unwrap()
- }
- _ => {
- let power = self.as_power(env);
- power.as_product(env)
- }
- }
- }
- }
- impl<T: NodeData> Node<T> for Primitive<T> {
- fn partial_derivative(&self, variable: &T::Parameter, env: &mut Environment<T>) -> Arc<Sum<T>> {
- unimplemented!();
- }
- fn fancy_clone(&self, env: &mut Environment<T>) -> Primitive<T> {
- match self {
- &Primitive::Sigmoid(minus, ref a) => Primitive::Sigmoid(minus, a.insert(env)),
- _ => self.clone(),
- }
- }
- fn get_storage<'a>(&self, env: &'a mut Environment<T>) -> &'a mut SelfMap<Primitive<T>> {
- &mut env.primitive
- }
- }
- impl<T: NodeData> Sum<T> {
- fn new(minus: bool, constant: T::Constant, terms: Vec<Arc<Product<T>>>) -> Sum<T> {
- let mut s = DefaultHasher::new();
- minus.hash(&mut s);
- constant.hash(&mut s);
- terms.hash(&mut s);
- Sum {
- pre_hash: s.finish(),
- minus: minus,
- constant: constant,
- terms: terms,
- }
- }
- }
- impl<T: NodeData> Product<T> {
- fn new(coefficient: T::Coefficient, powers: Vec<Arc<Power<T>>>) -> Product<T> {
- let mut s = DefaultHasher::new();
- coefficient.hash(&mut s);
- powers.hash(&mut s);
- Product {
- pre_hash: s.finish(),
- coefficient: coefficient,
- powers: powers,
- }
- }
- }
- impl<T: NodeData> Power<T> {
- fn new(exponent: T::Exponent, primitive: Arc<Primitive<T>>) -> Power<T> {
- Power {
- exponent: exponent,
- primitive: primitive,
- }
- }
- }
- impl<T: NodeData> Ord for Sum<T> {
- fn cmp(&self, other: &Sum<T>) -> Ordering {
- let cmp_result = self.constant.cmp(&other.constant);
- match cmp_result {
- Ordering::Equal => {
- for (self_term, other_term) in
- self.terms
- .clone()
- .into_iter()
- .zip(other.terms.clone()) {
- let cmp_result = self_term.cmp(&other_term);
- match cmp_result {
- Ordering::Equal => (),
- _ => return cmp_result,
- }
- }
- let cmp_result = self.terms.len().cmp(&other.terms.len());
- match cmp_result {
- Ordering::Equal => self.minus.cmp(&other.minus),
- _ => cmp_result,
- }
- }
- _ => cmp_result,
- }
- }
- }
- impl<T: NodeData> Ord for Product<T> {
- fn cmp(&self, other: &Product<T>) -> Ordering {
- let cmp_result = self.fuzzy_cmp(other);
- match cmp_result {
- Ordering::Equal => self.coefficient.cmp(&other.coefficient),
- _ => cmp_result,
- }
- }
- }
- impl<T: NodeData> Product<T> {
- fn fuzzy_cmp(&self, other: &Product<T>) -> Ordering {
- for (self_power, other_power) in
- self.powers
- .clone()
- .into_iter()
- .zip(other.powers.clone()) {
- let cmp_result = self_power.cmp(&other_power);
- match cmp_result {
- Ordering::Equal => (),
- _ => return cmp_result,
- }
- }
- self.powers.len().cmp(&other.powers.len())
- }
- }
- impl<T: NodeData> Ord for Power<T> {
- fn cmp(&self, other: &Power<T>) -> Ordering {
- let cmp_result = self.primitive.cmp(&other.primitive);
- match cmp_result {
- Ordering::Equal => self.exponent.cmp(&other.exponent),
- _ => cmp_result,
- }
- }
- }
- impl<T: NodeData> Ord for Primitive<T> {
- fn cmp(&self, other: &Primitive<T>) -> Ordering {
- unimplemented!()
- }
- }
- impl<T: NodeData> PartialOrd for Sum<T> {
- fn partial_cmp(&self, other: &Sum<T>) -> Option<Ordering> {
- Some(self.cmp(other))
- }
- }
- impl<T: NodeData> PartialOrd for Product<T> {
- fn partial_cmp(&self, other: &Product<T>) -> Option<Ordering> {
- Some(self.cmp(other))
- }
- }
- impl<T: NodeData> PartialOrd for Power<T> {
- fn partial_cmp(&self, other: &Power<T>) -> Option<Ordering> {
- Some(self.cmp(other))
- }
- }
- impl<T: NodeData> PartialOrd for Primitive<T> {
- fn partial_cmp(&self, other: &Primitive<T>) -> Option<Ordering> {
- Some(self.cmp(other))
- }
- }
- impl<T: NodeData> Hash for Sum<T> {
- fn hash<H: Hasher>(&self, state: &mut H) {
- self.pre_hash.hash(state);
- }
- }
- impl<T: NodeData> Hash for Product<T> {
- fn hash<H: Hasher>(&self, state: &mut H) {
- self.pre_hash.hash(state);
- }
- }
- impl<T: NodeData> Hash for Power<T> {
- fn hash<H: Hasher>(&self, state: &mut H) {
- self.exponent.hash(state);
- self.primitive.hash(state);
- }
- }
- impl<T: NodeData> Hash for Primitive<T> {
- fn hash<H: Hasher>(&self, state: &mut H) {
- match self {
- &Primitive::Input(ref a) => {
- "Input".hash(state);
- a.hash(state)
- }
- &Primitive::SystemVariable(ref a) => {
- "SystemVariable".hash(state);
- a.hash(state)
- }
- &Primitive::Parameter(ref a) => {
- "Parameter".hash(state);
- a.hash(state)
- }
- &Primitive::Sigmoid(minus, ref a) => {
- "Sigmoid".hash(state);
- minus.hash(state);
- a.hash(state)
- }
- }
- }
- }
- impl<T: NodeData> PartialEq for Sum<T> {
- fn eq(&self, other: &Sum<T>) -> bool {
- self.minus == other.minus && self.constant == other.constant && self.terms == other.terms
- }
- }
- impl<T: NodeData> PartialEq for Product<T> {
- fn eq(&self, other: &Product<T>) -> bool {
- self.coefficient == other.coefficient && self.powers == other.powers
- }
- }
- impl<T: NodeData> PartialEq for Power<T> {
- fn eq(&self, other: &Power<T>) -> bool {
- self.exponent == other.exponent && self.primitive == other.primitive
- }
- }
- impl<T: NodeData> PartialEq for Primitive<T> {
- fn eq(&self, other: &Primitive<T>) -> bool {
- match (self, other) {
- (&Primitive::Input(ref a), &Primitive::Input(ref b)) => a == b,
- (&Primitive::SystemVariable(ref a), &Primitive::SystemVariable(ref b)) => a == b,
- (&Primitive::Parameter(ref a), &Primitive::Parameter(ref b)) => a == b,
- (&Primitive::Sigmoid(minus_a, ref a), &Primitive::Sigmoid(minus_b, ref b)) => {
- minus_a == minus_b && a == b
- }
- _ => false,
- }
- }
- }
- impl<T: NodeData> Eq for Sum<T> {}
- impl<T: NodeData> Eq for Product<T> {}
- impl<T: NodeData> Eq for Power<T> {}
- impl<T: NodeData> Eq for Primitive<T> {}
- impl<T: NodeData> Clone for Sum<T> {
- fn clone(&self) -> Sum<T> {
- Sum {
- pre_hash: self.pre_hash.clone(),
- minus: self.minus.clone(),
- constant: self.constant.clone(),
- terms: self.terms.clone(),
- }
- }
- }
- impl<T: NodeData> Clone for Product<T> {
- fn clone(&self) -> Product<T> {
- Product {
- pre_hash: self.pre_hash.clone(),
- coefficient: self.coefficient.clone(),
- powers: self.powers.clone(),
- }
- }
- }
- impl<T: NodeData> Clone for Power<T> {
- fn clone(&self) -> Power<T> {
- Power {
- exponent: self.exponent.clone(),
- primitive: self.primitive.clone(),
- }
- }
- }
- impl<T: NodeData> Clone for Primitive<T> {
- fn clone(&self) -> Primitive<T> {
- match self {
- &Primitive::Input(ref a) => Primitive::Input(a.clone()),
- &Primitive::SystemVariable(ref a) => Primitive::SystemVariable(a.clone()),
- &Primitive::Parameter(ref a) => Primitive::Parameter(a.clone()),
- &Primitive::Sigmoid(minus, ref a) => Primitive::Sigmoid(minus, a.clone()),
- }
- }
- }
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::sync::Arc;

    /// Concrete `NodeData` backed by primitive integers and static-string parameter names.
    struct TestData;

    impl NodeData for TestData {
        type Constant = i64;
        type Coefficient = i64;
        type Exponent = u32;
        type Input = usize;
        type SystemVariable = usize;
        type Parameter = &'static str;
    }

    /// An environment with empty intern tables.
    fn empty_env() -> Environment<TestData> {
        Environment {
            sum: HashMap::new(),
            product: HashMap::new(),
            power: HashMap::new(),
            primitive: HashMap::new(),
        }
    }

    #[test]
    fn primitives_are_interned() {
        let mut env = empty_env();
        let first = env.make_input(&0);
        let again = env.make_input(&0);
        let other = env.make_input(&1);
        assert!(Arc::ptr_eq(&first, &again));
        assert!(!Arc::ptr_eq(&first, &other));
        // The same payload under a different kind of primitive must not collide.
        let system = env.make_system_variable(&0);
        assert!(!Arc::ptr_eq(&first, &system));
        assert_eq!(env.primitive.len(), 3);
    }

    #[test]
    fn powers_need_a_positive_exponent() {
        let mut env = empty_env();
        let x = env.make_input(&0);
        assert!(env.make_power(&x, &0).is_none());
        let squared = env.make_power(&x, &2).expect("exponent 2 is positive");
        assert_eq!(squared.exponent, 2);
        assert!(Arc::ptr_eq(&squared.primitive, &x));
    }

    #[test]
    fn products_reject_zero_coefficients_and_empty_powers() {
        let mut env = empty_env();
        let x = env.make_input(&0);
        let power = env.make_power(&x, &1).unwrap();
        let no_powers: Vec<Arc<Power<TestData>>> = Vec::new();
        assert!(env.make_product(&0, &vec![power.clone()]).is_none());
        assert!(env.make_product(&3, &no_powers).is_none());
        let product = env.make_product(&3, &vec![power.clone()]).unwrap();
        assert_eq!(product.negate(&mut env).coefficient, -3);
    }

    #[test]
    fn negated_constant_sums_fold_the_sign_into_the_constant() {
        let mut env = empty_env();
        let no_terms: Vec<Arc<Product<TestData>>> = Vec::new();
        let sum = env.make_sum(true, &3, &no_terms);
        assert!(!sum.minus);
        assert_eq!(sum.constant, -3);
        assert!(Arc::ptr_eq(&sum, &env.make_sum(false, &-3, &no_terms)));
    }

    #[test]
    fn sigmoid_of_negative_constant_flips_its_sign_flag() {
        let mut env = empty_env();
        let no_terms: Vec<Arc<Product<TestData>>> = Vec::new();
        let negative = env.make_sum(false, &-2, &no_terms);
        let positive = env.make_sum(false, &2, &no_terms);
        let sigmoid = env.make_sigmoid(false, &negative);
        match *sigmoid {
            Primitive::Sigmoid(minus, ref argument) => {
                assert!(minus);
                assert!(Arc::ptr_eq(argument, &positive));
            }
            _ => panic!("make_sigmoid must build a Sigmoid primitive"),
        }
    }

    #[test]
    fn negative_products_become_minus_sums() {
        let mut env = empty_env();
        let x = env.make_input(&0);
        let power = env.make_power(&x, &1).unwrap();
        let product = env.make_product(&-4, &vec![power]).unwrap();
        let sum = product.as_sum(&mut env);
        assert!(sum.minus);
        assert_eq!(sum.constant, 0);
        assert_eq!(sum.terms.len(), 1);
        assert_eq!(sum.terms[0].coefficient, 4);
    }
}
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement