Advertisement
mwchase

LC:NN upload 6: naive implementation

Apr 9th, 2017
114
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Diff 3.30 KB | None | 0 0
  1. -use std::marker::PhantomData;
  2. +//use std::marker::PhantomData;
  3. +
  4. +use std::collections::HashMap;
  5.  
  6.  use void::Void;
  7.  
  8. -trait Input<I, T> {
  9. +use Node;
  10. +use node_defs::{Container, NodeData, Sum, Product, Power, Primitive};
  11. +
  12. +/*trait Input<I, T> {
  13.      fn get(&self, index: I) -> T where T: Clone;
  14.  }
  15.  
  16. @@ -25,4 +30,83 @@
  17.      type Parameter: Parameter<Self::ParameterIndex, T>;
  18.      type ConstantIndex;
  19.      type Constant: Constant<Self::ConstantIndex, T>;
  20. +}*/
  21. +
  22. +struct State {
  23. +    inputs: Vec<f64>,
  24. +    parameters: Vec<f64>,
  25.  }
  26. +
  27. +struct Cache<T: NodeData> {
  28. +    sums: HashMap<Container<Sum<T>>, f64>,
  29. +    products: HashMap<Container<Product<T>>, f64>,
  30. +    powers: HashMap<Container<Power<T>>, f64>,
  31. +    primitives: HashMap<Container<Primitive<T>>, f64>,
  32. +}
  33. +
  34. +impl<T: NodeData> Cache<T> {
  35. +    fn new() -> Cache<T> {
  36. +        Cache {
  37. +            sums: HashMap::new(),
  38. +            products: HashMap::new(),
  39. +            powers: HashMap::new(),
  40. +            primitives: HashMap::new(),
  41. +        }
  42. +    }
  43. +}
  44. +
  45. +impl<T: NodeData<Input = usize, Parameter = usize>> Cache<T>
  46. +    where T::Constant: Into<f64>,
  47. +          T::Coefficient: Into<f64>,
  48. +          T::Exponent: Into<i32>
  49. +{
  50. +    fn eval_sum(&mut self, state: &State, sum: Container<Sum<T>>) -> f64 {
  51. +        if !self.sums.contains_key(&sum.clone()) {
  52. +            let mut total = sum.constant.clone().into();
  53. +            for product in &sum.terms {
  54. +                total += self.eval_product(state, product.clone());
  55. +            }
  56. +            if sum.minus {
  57. +                total = -total;
  58. +            }
  59. +            self.sums.insert(sum.clone(), total);
  60. +        }
  61. +        self.sums[&sum.clone()]
  62. +    }
  63. +
  64. +    fn eval_product(&mut self, state: &State, product: Container<Product<T>>) -> f64 {
  65. +        if !self.products.contains_key(&product.clone()) {
  66. +            let mut total = product.coefficient.clone().into();
  67. +            for power in &product.powers {
  68. +                total *= self.eval_power(state, power.clone());
  69. +            }
  70. +            self.products.insert(product.clone(), total);
  71. +        }
  72. +        self.products[&product.clone()]
  73. +    }
  74. +
  75. +    fn eval_power(&mut self, state: &State, power: Container<Power<T>>) -> f64 {
  76. +        if !self.powers.contains_key(&power.clone()) {
  77. +            let result = self.eval_primitive(state, power.primitive.clone())
  78. +                .powi(power.exponent.clone().into());
  79. +            self.powers.insert(power.clone(), result);
  80. +        }
  81. +        self.powers[&power.clone()]
  82. +    }
  83. +
  84. +    fn eval_primitive(&mut self, state: &State, primitive: Container<Primitive<T>>) -> f64 {
  85. +        if !self.primitives.contains_key(&primitive.clone()) {
  86. +            let result = match primitive.as_ref() {
  87. +                &Primitive::Input(index) => state.inputs[index],
  88. +                &Primitive::Real(_) => unimplemented!(),
  89. +                &Primitive::Parameter(index) => state.parameters[index],
  90. +                &Primitive::Sigmoid(minus, ref sum) => {
  91. +                    let tanh = self.eval_sum(state, sum.clone()).tanh();
  92. +                    if minus { -tanh } else { tanh }
  93. +                }
  94. +            };
  95. +            self.primitives.insert(primitive.clone(), result);
  96. +        }
  97. +        self.primitives[&primitive.clone()]
  98. +    }
  99. +}
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement