Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- -use std::marker::PhantomData;
- +//use std::marker::PhantomData;
- +
- +use std::collections::HashMap;
- use void::Void;
- -trait Input<I, T> {
- +use Node;
- +use node_defs::{Container, NodeData, Sum, Product, Power, Primitive};
- +
- +/*trait Input<I, T> {
- fn get(&self, index: I) -> T where T: Clone;
- }
- @@ -25,4 +30,83 @@
- type Parameter: Parameter<Self::ParameterIndex, T>;
- type ConstantIndex;
- type Constant: Constant<Self::ConstantIndex, T>;
- +}*/
- +
/// Numeric values that leaf (primitive) nodes read when an expression is evaluated.
struct State {
    /// Looked up by `Primitive::Input(i)` as `inputs[i]`.
    inputs: Vec<f64>,
    /// Looked up by `Primitive::Parameter(i)` as `parameters[i]`.
    parameters: Vec<f64>,
}
- +
- +struct Cache<T: NodeData> {
- + sums: HashMap<Container<Sum<T>>, f64>,
- + products: HashMap<Container<Product<T>>, f64>,
- + powers: HashMap<Container<Power<T>>, f64>,
- + primitives: HashMap<Container<Primitive<T>>, f64>,
- +}
- +
- +impl<T: NodeData> Cache<T> {
- + fn new() -> Cache<T> {
- + Cache {
- + sums: HashMap::new(),
- + products: HashMap::new(),
- + powers: HashMap::new(),
- + primitives: HashMap::new(),
- + }
- + }
- +}
- +
- +impl<T: NodeData<Input = usize, Parameter = usize>> Cache<T>
- + where T::Constant: Into<f64>,
- + T::Coefficient: Into<f64>,
- + T::Exponent: Into<i32>
- +{
- + fn eval_sum(&mut self, state: &State, sum: Container<Sum<T>>) -> f64 {
- + if !self.sums.contains_key(&sum.clone()) {
- + let mut total = sum.constant.clone().into();
- + for product in &sum.terms {
- + total += self.eval_product(state, product.clone());
- + }
- + if sum.minus {
- + total = -total;
- + }
- + self.sums.insert(sum.clone(), total);
- + }
- + self.sums[&sum.clone()]
- + }
- +
- + fn eval_product(&mut self, state: &State, product: Container<Product<T>>) -> f64 {
- + if !self.products.contains_key(&product.clone()) {
- + let mut total = product.coefficient.clone().into();
- + for power in &product.powers {
- + total *= self.eval_power(state, power.clone());
- + }
- + self.products.insert(product.clone(), total);
- + }
- + self.products[&product.clone()]
- + }
- +
- + fn eval_power(&mut self, state: &State, power: Container<Power<T>>) -> f64 {
- + if !self.powers.contains_key(&power.clone()) {
- + let result = self.eval_primitive(state, power.primitive.clone())
- + .powi(power.exponent.clone().into());
- + self.powers.insert(power.clone(), result);
- + }
- + self.powers[&power.clone()]
- + }
- +
- + fn eval_primitive(&mut self, state: &State, primitive: Container<Primitive<T>>) -> f64 {
- + if !self.primitives.contains_key(&primitive.clone()) {
- + let result = match primitive.as_ref() {
- + &Primitive::Input(index) => state.inputs[index],
- + &Primitive::Real(_) => unimplemented!(),
- + &Primitive::Parameter(index) => state.parameters[index],
- + &Primitive::Sigmoid(minus, ref sum) => {
- + let tanh = self.eval_sum(state, sum.clone()).tanh();
- + if minus { -tanh } else { tanh }
- + }
- + };
- + self.primitives.insert(primitive.clone(), result);
- + }
- + self.primitives[&primitive.clone()]
- + }
- +}
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement