Advertisement
mwchase

LC:NN upload 7: bye for now

Apr 13th, 2017
225
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Diff 17.39 KB | None | 0 0
  1. diff --git a/rust/adagio/Cargo.toml b/rust/adagio/Cargo.toml
  2. --- a/rust/adagio/Cargo.toml
  3. +++ b/rust/adagio/Cargo.toml
  4. @@ -6,3 +6,4 @@
  5.  [dependencies]
  6.  num = "0.1.37"
  7.  num-rational = "0.1.36"
  8. +void = "1.0.2"
  9. diff --git a/rust/adagio/src/builder.rs b/rust/adagio/src/builder.rs
  10. --- a/rust/adagio/src/builder.rs
  11. +++ b/rust/adagio/src/builder.rs
  12. @@ -1,21 +1,21 @@
  13. -use num_rational::Rational;
  14. +use num_rational::Ratio;
  15.  
  16. -use {Environment, Node};
  17. +use {Environment, LiteralType, Node};
  18.  use node_defs::{Container, NodeData, Sum};
  19. +use node_ops::NodeExpr;
  20.  
  21. -impl<T:NodeData> Environment<T> {
  22. -    fn swlu<I: Node<T>>(&mut self, input: Container<I>) -> Container<Sum<T>> where T::Constant: From<Rational> {
  23. -        let leakiness = self.number(Rational::new(1, 10));
  24. -        let one_half = self.number(Rational::new(1, 2));
  25. -        let one = self.number(Rational::new(1, 1));
  26. -        let minus_one = self.number(Rational::new(-1, 1));
  27. -        let tanh = self.tanh(input.clone());
  28. -        let minus_tanh = self.multiply(minus_one, tanh.clone());
  29. -        let one_plus = self.add(one.clone(), tanh);
  30. -        let one_minus = self.add(one, minus_tanh);
  31. -        let scaled = self.multiply(leakiness, one_minus);
  32. -        let sum = self.add(one_plus, scaled);
  33. -        let half_input = self.multiply(one_half, input);
  34. -        self.multiply(half_input, sum)
  35. +impl<T: NodeData> NodeExpr<T> {
  36. +    pub fn swlu(self) -> NodeExpr<T>
  37. +        where T::Constant: From<Ratio<LiteralType>>
  38. +    {
  39. +        self.clone() * constant((1, 2)) *
  40. +        ((self.clone().tanh() + constant(1)) +
  41. +         (constant(1) - self.clone().tanh()) * constant((1, 10)))
  42.      }
  43. -}
  44. \ No newline at end of file
  45. +}
  46. +
  47. +fn constant<T: NodeData, C: Into<Ratio<LiteralType>>>(constant: C) -> NodeExpr<T>
  48. +    where T::Constant: From<Ratio<LiteralType>>
  49. +{
  50. +    NodeExpr::constant(constant.into())
  51. +}
  52. diff --git a/rust/adagio/src/eval.rs b/rust/adagio/src/eval.rs
  53. --- a/rust/adagio/src/eval.rs
  54. +++ b/rust/adagio/src/eval.rs
  55. @@ -66,11 +66,11 @@
  56.  
  57.      fn eval_primitive(&mut self, state: &State, primitive: Container<Primitive<T>>) -> f64 {
  58.          if !self.primitives.contains_key(&primitive.clone()) {
  59. -            let result = match primitive.as_ref() {
  60. -                &Primitive::Input(index) => state.inputs[index],
  61. -                &Primitive::Real(_) => unimplemented!(),
  62. -                &Primitive::Parameter(index) => state.parameters[index],
  63. -                &Primitive::Sigmoid(minus, ref sum) => {
  64. +            let result = match *primitive.as_ref() {
  65. +                Primitive::Input(index) => state.inputs[index],
  66. +                Primitive::Real(_) => unimplemented!(),
  67. +                Primitive::Parameter(index) => state.parameters[index],
  68. +                Primitive::Sigmoid(minus, ref sum) => {
  69.                      let tanh = self.eval_sum(state, sum.clone()).tanh();
  70.                      if minus { -tanh } else { tanh }
  71.                  }
  72. diff --git a/rust/adagio/src/lib.rs b/rust/adagio/src/lib.rs
  73. --- a/rust/adagio/src/lib.rs
  74. +++ b/rust/adagio/src/lib.rs
  75. @@ -1,5 +1,6 @@
  76.  extern crate num;
  77.  extern crate num_rational;
  78. +extern crate void;
  79.  
  80.  use std::cmp::Ordering;
  81.  use std::collections::{HashMap, HashSet};
  82. @@ -9,10 +10,15 @@
  83.  mod eval;
  84.  mod node_conversions;
  85.  mod node_defs;
  86. +mod node_ops;
  87.  mod node_ordering;
  88.  
  89.  use node_conversions::Summable;
  90.  use node_defs::{Container, container, FancyClone, NodeData, Sum, Product, Power, Primitive};
  91. +use node_ops::{AsExpr, NodeExpr};
  92. +
  93. +///Writing code that's generic over scalars is suffering.
  94. +type LiteralType = i32;
  95.  
  96.  /// Negates a Sum node.
  97.  impl<T: NodeData> Sum<T> {
  98. @@ -77,6 +83,19 @@
  99.  }
  100.  
  101.  impl<T: NodeData> Environment<T> {
  102. +    pub fn new() -> Environment<T> {
  103. +        Environment {
  104. +            sum: HashSet::new(),
  105. +            product: HashSet::new(),
  106. +            power: HashSet::new(),
  107. +            primitive: HashSet::new(),
  108. +            sum_derivative: HashMap::new(),
  109. +            product_derivative: HashMap::new(),
  110. +            power_derivative: HashMap::new(),
  111. +            primitive_derivative: HashMap::new(),
  112. +        }
  113. +    }
  114. +
  115.      /// Makes a canonicalized Sum from the given components.
  116.      pub fn make_sum<C, N: Into<T::Constant>>(&mut self,
  117.                                               minus: bool,
  118. @@ -298,9 +317,7 @@
  119.              }
  120.          }
  121.  
  122. -        last = self.add(last, inner);
  123. -        last = self.add(last, outer);
  124. -        self.add(last, first)
  125. +        self.eval(first.expr() + outer.expr() + inner.expr() + last.expr())
  126.      }
  127.  
  128.      fn _combine_terms(&mut self,
  129. @@ -314,9 +331,18 @@
  130.          let mut left_index = 0;
  131.          let mut right_index = 0;
  132.  
  133. -        while left_index < left_length && right_index < right_length {
  134. +        println!("Vectors equal: {}", left == right);
  135. +        //println!("Left: {:?} right: {:?}", left, right);
  136. +
  137. +        while (left_index < left_length) && (right_index < right_length) {
  138. +            assert!(left_index < left.len(), "Left OOB");
  139. +            assert!(right_index < right.len(), "Right OOB");
  140. +            println!("Right len: {} index: {}", right.len(), right_index);
  141. +            let ref right_power = right[right_length];
  142. +            println!("Right indexing succeeded");
  143. +            println!("Left len: {} index: {} right len: {} index: {}", left.len(), left_index, right.len(), right_index);
  144.              let ref left_power = left[left_index];
  145. -            let ref right_power = right[right_length];
  146. +            println!("Left indexing succeeded");
  147.              match left_power.fuzzy_cmp(right_power.as_ref()) {
  148.                  Ordering::Less => {
  149.                      powers.push(left_power.clone());
  150. @@ -351,9 +377,38 @@
  151.          let sum = sum.as_sum(self);
  152.          self.make_sigmoid(false, sum)
  153.      }
  154. +
  155. +    pub fn eval(&mut self, expr: NodeExpr<T>) -> Container<Sum<T>> {
  156. +        match expr {
  157. +            NodeExpr::Sum(a) => a.insert(self),
  158. +            NodeExpr::Product(a) => a.as_sum(self),
  159. +            NodeExpr::Power(a) => a.as_sum(self),
  160. +            NodeExpr::Primitive(a) => a.as_sum(self),
  161. +            NodeExpr::Constant(a) => self.number(a),
  162. +            NodeExpr::Add(a, b) => {
  163. +                let left = self.eval(*a);
  164. +                let right = self.eval(*b);
  165. +                self.add(left, right)
  166. +            }
  167. +            NodeExpr::Subtract(a, b) => {
  168. +                let left = self.eval(*a);
  169. +                let right = self.eval(*b).negate(self);
  170. +                self.add(left, right)
  171. +            }
  172. +            NodeExpr::Multiply(a, b) => {
  173. +                let left = self.eval(*a);
  174. +                let right = self.eval(*b);
  175. +                self.multiply(left, right)
  176. +            }
  177. +            NodeExpr::Tanh(a) => {
  178. +                let sum = self.eval(*a);
  179. +                self.tanh(sum).as_sum(self)
  180. +            }
  181. +        }
  182. +    }
  183.  }
  184.  
  185. -pub trait Node<T: NodeData>: Summable<T> + FancyClone<T> {
  186. +pub trait Node<T: NodeData>: Summable<T> + FancyClone<T> + AsExpr<T> {
  187.      fn partial_derivative(&self,
  188.                            variable: &T::Parameter,
  189.                            env: &mut Environment<T>)
  190. @@ -497,8 +552,8 @@
  191.                         variable: &T::Parameter,
  192.                         env: &mut Environment<T>)
  193.                         -> Container<Sum<T>> {
  194. -        match self {
  195. -            &Primitive::Sigmoid(minus, ref sum) => {
  196. +        match *self {
  197. +            Primitive::Sigmoid(minus, ref sum) => {
  198.                  let derivative = sum.partial_derivative(variable, env);
  199.                  let self_in_env = self.insert(env);
  200.                  let self_squared = env.secret_make_power(self_in_env, 2);
  201. @@ -506,7 +561,7 @@
  202.                  let sum = env.make_sum(minus, 1, vec![minus_self_squared]);
  203.                  env.multiply(derivative, sum)
  204.              }
  205. -            &Primitive::Parameter(ref p) => env.number(if p == variable { 1 } else { 0 }),
  206. +            Primitive::Parameter(ref p) => env.number(if p == variable { 1 } else { 0 }),
  207.              _ => env.number(0),
  208.          }
  209.      }
  210. @@ -527,6 +582,50 @@
  211.  
  212.  #[cfg(test)]
  213.  mod tests {
  214. +    use num_rational;
  215. +    use void;
  216. +
  217. +    use {Environment, LiteralType, NodeData};
  218. +    use node_ops::AsExpr;
  219.      #[test]
  220. -    fn it_works() {}
  221. +    fn construct_xor() {
  222. +        struct MyNodeData(void::Void);
  223. +
  224. +        impl NodeData for MyNodeData {
  225. +            type Constant = num_rational::Ratio<LiteralType>;
  226. +            type Coefficient = num_rational::Ratio<LiteralType>;
  227. +            type Exponent = LiteralType;
  228. +            type Input = usize;
  229. +            type Real = bool; // Pls no use.
  230. +            type Parameter = usize;
  231. +        }
  232. +
  233. +        let mut env = Environment::<MyNodeData>::new();
  234. +        let mut param_count = 0..;
  235. +
  236. +        let left_input = env.make_input(0);
  237. +        let right_input = env.make_input(1);
  238. +
  239. +        let left_hidden =
  240. +            (left_input.expr() * env.make_parameter(param_count.next().unwrap()).expr() +
  241. +             right_input.expr() * env.make_parameter(param_count.next().unwrap()).expr() +
  242. +             env.make_parameter(param_count.next().unwrap()).expr())
  243. +                    .swlu();
  244. +        let right_hidden =
  245. +            (left_input.expr() * env.make_parameter(param_count.next().unwrap()).expr() +
  246. +             right_input.expr() * env.make_parameter(param_count.next().unwrap()).expr() +
  247. +             env.make_parameter(param_count.next().unwrap()).expr())
  248. +                    .swlu();
  249. +
  250. +        let result = left_hidden * env.make_parameter(param_count.next().unwrap()).expr() +
  251. +                     right_hidden * env.make_parameter(param_count.next().unwrap()).expr() +
  252. +                     env.make_parameter(param_count.next().unwrap()).expr();
  253. +
  254. +        let expected = env.make_input(2);
  255. +        let difference = result.clone() - expected.expr();
  256. +        let error = difference.clone() * difference;
  257. +
  258. +        let result_container = env.eval(result);
  259. +        let error_container = env.eval(error);
  260. +    }
  261.  }
  262. diff --git a/rust/adagio/src/node_conversions.rs b/rust/adagio/src/node_conversions.rs
  263. --- a/rust/adagio/src/node_conversions.rs
  264. +++ b/rust/adagio/src/node_conversions.rs
  265. @@ -44,8 +44,8 @@
  266.  
  267.  impl<T: NodeData> Productable<T> for Primitive<T> {
  268.      fn as_product(&self, env: &mut Environment<T>) -> Container<Product<T>> {
  269. -        match self {
  270. -            &Primitive::Sigmoid(true, ref sum) => {
  271. +        match *self {
  272. +            Primitive::Sigmoid(true, ref sum) => {
  273.                  let sigmoid = env.make_sigmoid(false, sum.clone());
  274.                  let power = sigmoid.as_power(env);
  275.                  env.secret_make_product(-1, vec![power])
  276. diff --git a/rust/adagio/src/node_defs.rs b/rust/adagio/src/node_defs.rs
  277. --- a/rust/adagio/src/node_defs.rs
  278. +++ b/rust/adagio/src/node_defs.rs
  279. @@ -5,7 +5,7 @@
  280.  
  281.  use num::{Integer, Signed};
  282.  
  283. -use {Environment, Node};
  284. +use {Environment, LiteralType, Node};
  285.  
  286.  pub type Container<T> = Arc<T>;
  287.  
  288. @@ -14,9 +14,9 @@
  289.  }
  290.  
  291.  pub trait NodeData {
  292. -    type Constant: Hash + Ord + Clone + Signed + Into<Self::Coefficient> + From<i32>;
  293. -    type Coefficient: Hash + Ord + Clone + Signed + Into<Self::Constant> + From<i32>;
  294. -    type Exponent: Hash + Clone + Integer + Into<Self::Constant> + From<i32>;
  295. +    type Constant: Hash + Ord + Clone + Signed + Into<Self::Coefficient> + From<LiteralType>;
  296. +    type Coefficient: Hash + Ord + Clone + Signed + Into<Self::Constant> + From<LiteralType>;
  297. +    type Exponent: Hash + Clone + Integer + Into<Self::Constant> + From<LiteralType>;
  298.      type Input: Hash + Ord + Clone;
  299.      type Real: Hash + Ord + Clone;
  300.      type Parameter: Hash + Ord + Clone;
  301. @@ -130,20 +130,20 @@
  302.  
  303.  impl<T: NodeData> Hash for Primitive<T> {
  304.      fn hash<H: Hasher>(&self, state: &mut H) {
  305. -        match self {
  306. -            &Primitive::Input(ref a) => {
  307. +        match *self {
  308. +            Primitive::Input(ref a) => {
  309.                  "Input".hash(state);
  310.                  a.hash(state)
  311.              }
  312. -            &Primitive::Real(ref a) => {
  313. +            Primitive::Real(ref a) => {
  314.                  "Real".hash(state);
  315.                  a.hash(state)
  316.              }
  317. -            &Primitive::Parameter(ref a) => {
  318. +            Primitive::Parameter(ref a) => {
  319.                  "Parameter".hash(state);
  320.                  a.hash(state)
  321.              }
  322. -            &Primitive::Sigmoid(minus, ref a) => {
  323. +            Primitive::Sigmoid(minus, ref a) => {
  324.                  "Sigmoid".hash(state);
  325.                  minus.hash(state);
  326.                  a.hash(state)
  327. @@ -184,8 +184,8 @@
  328.  
  329.  impl<T: NodeData> FancyClone<T> for Primitive<T> {
  330.      fn fancy_clone(&self, env: &mut Environment<T>) -> Primitive<T> {
  331. -        match self {
  332. -            &Primitive::Sigmoid(minus, ref a) => Primitive::Sigmoid(minus, a.insert(env)),
  333. +        match *self {
  334. +            Primitive::Sigmoid(minus, ref a) => Primitive::Sigmoid(minus, a.insert(env)),
  335.              _ => self.clone(),
  336.          }
  337.      }
  338. @@ -223,11 +223,11 @@
  339.  
  340.  impl<T: NodeData> Clone for Primitive<T> {
  341.      fn clone(&self) -> Primitive<T> {
  342. -        match self {
  343. -            &Primitive::Input(ref a) => Primitive::Input(a.clone()),
  344. -            &Primitive::Real(ref a) => Primitive::Real(a.clone()),
  345. -            &Primitive::Parameter(ref a) => Primitive::Parameter(a.clone()),
  346. -            &Primitive::Sigmoid(minus, ref a) => Primitive::Sigmoid(minus, a.clone()),
  347. +        match *self {
  348. +            Primitive::Input(ref a) => Primitive::Input(a.clone()),
  349. +            Primitive::Real(ref a) => Primitive::Real(a.clone()),
  350. +            Primitive::Parameter(ref a) => Primitive::Parameter(a.clone()),
  351. +            Primitive::Sigmoid(minus, ref a) => Primitive::Sigmoid(minus, a.clone()),
  352.          }
  353.      }
  354.  }
  355. diff --git a/rust/adagio/src/node_ops.rs b/rust/adagio/src/node_ops.rs
  356. new file mode 100644
  357. --- /dev/null
  358. +++ b/rust/adagio/src/node_ops.rs
  359. @@ -0,0 +1,94 @@
  360. +use std::ops::{Add, Mul, Sub};
  361. +
  362. +use Node;
  363. +use node_defs::{NodeData, Sum, Product, Power, Primitive};
  364. +
  365. +pub enum NodeExpr<T: NodeData> {
  366. +    Sum(Sum<T>),
  367. +    Product(Product<T>),
  368. +    Power(Power<T>),
  369. +    Primitive(Primitive<T>),
  370. +    Constant(T::Constant),
  371. +    Add(Box<NodeExpr<T>>, Box<NodeExpr<T>>),
  372. +    Subtract(Box<NodeExpr<T>>, Box<NodeExpr<T>>),
  373. +    Multiply(Box<NodeExpr<T>>, Box<NodeExpr<T>>),
  374. +    Tanh(Box<NodeExpr<T>>),
  375. +}
  376. +
  377. +impl<T: NodeData> Clone for NodeExpr<T> {
  378. +    fn clone(&self) -> NodeExpr<T> {
  379. +        match *self {
  380. +            NodeExpr::Sum(ref a) => NodeExpr::Sum(a.clone()),
  381. +            NodeExpr::Product(ref a) => NodeExpr::Product(a.clone()),
  382. +            NodeExpr::Power(ref a) => NodeExpr::Power(a.clone()),
  383. +            NodeExpr::Primitive(ref a) => NodeExpr::Primitive(a.clone()),
  384. +            NodeExpr::Constant(ref a) => NodeExpr::Constant(a.clone()),
  385. +            NodeExpr::Add(ref a, ref b) => NodeExpr::Add(a.clone(), b.clone()),
  386. +            NodeExpr::Subtract(ref a, ref b) => NodeExpr::Subtract(a.clone(), b.clone()),
  387. +            NodeExpr::Multiply(ref a, ref b) => NodeExpr::Multiply(a.clone(), b.clone()),
  388. +            NodeExpr::Tanh(ref a) => NodeExpr::Tanh(a.clone()),
  389. +        }
  390. +    }
  391. +}
  392. +
  393. +impl<T: NodeData> Add for NodeExpr<T> {
  394. +    type Output = NodeExpr<T>;
  395. +
  396. +    fn add(self, rhs: NodeExpr<T>) -> NodeExpr<T> {
  397. +        NodeExpr::Add(Box::new(self), Box::new(rhs))
  398. +    }
  399. +}
  400. +
  401. +impl<T: NodeData> Sub for NodeExpr<T> {
  402. +    type Output = NodeExpr<T>;
  403. +
  404. +    fn sub(self, rhs: NodeExpr<T>) -> NodeExpr<T> {
  405. +        NodeExpr::Subtract(Box::new(self), Box::new(rhs))
  406. +    }
  407. +}
  408. +
  409. +impl<T: NodeData> Mul for NodeExpr<T> {
  410. +    type Output = NodeExpr<T>;
  411. +
  412. +    fn mul(self, rhs: NodeExpr<T>) -> NodeExpr<T> {
  413. +        NodeExpr::Multiply(Box::new(self), Box::new(rhs))
  414. +    }
  415. +}
  416. +
  417. +pub trait AsExpr<T: NodeData> {
  418. +    fn expr(&self) -> NodeExpr<T>;
  419. +}
  420. +
  421. +impl<T: NodeData> AsExpr<T> for Sum<T> {
  422. +    fn expr(&self) -> NodeExpr<T> {
  423. +        NodeExpr::Sum(self.clone())
  424. +    }
  425. +}
  426. +
  427. +impl<T: NodeData> AsExpr<T> for Product<T> {
  428. +    fn expr(&self) -> NodeExpr<T> {
  429. +        NodeExpr::Product(self.clone())
  430. +    }
  431. +}
  432. +
  433. +impl<T: NodeData> AsExpr<T> for Power<T> {
  434. +    fn expr(&self) -> NodeExpr<T> {
  435. +        NodeExpr::Power(self.clone())
  436. +    }
  437. +}
  438. +
  439. +impl<T: NodeData> AsExpr<T> for Primitive<T> {
  440. +    fn expr(&self) -> NodeExpr<T> {
  441. +        NodeExpr::Primitive(self.clone())
  442. +    }
  443. +}
  444. +
  445. +impl<T: NodeData> NodeExpr<T> {
  446. +    pub fn tanh(self) -> NodeExpr<T> {
  447. +        NodeExpr::Tanh(Box::new(self))
  448. +    }
  449. +
  450. +    pub fn constant<C: Into<T::Constant>>(constant: C) -> NodeExpr<T> {
  451. +        NodeExpr::Constant(constant.into())
  452. +    }
  453. +}
  454. diff --git a/rust/adagio/src/node_ordering.rs b/rust/adagio/src/node_ordering.rs
  455. --- a/rust/adagio/src/node_ordering.rs
  456. +++ b/rust/adagio/src/node_ordering.rs
  457. @@ -12,11 +12,11 @@
  458.  
  459.  impl<T: NodeData> Primitive<T> {
  460.      fn simplify(&self) -> PrimitiveType {
  461. -        match self {
  462. -            &Primitive::Input(_) => PrimitiveType::Input,
  463. -            &Primitive::Real(_) => PrimitiveType::Real,
  464. -            &Primitive::Parameter(_) => PrimitiveType::Parameter,
  465. -            &Primitive::Sigmoid(_, _) => PrimitiveType::Sigmoid,
  466. +        match *self {
  467. +            Primitive::Input(_) => PrimitiveType::Input,
  468. +            Primitive::Real(_) => PrimitiveType::Real,
  469. +            Primitive::Parameter(_) => PrimitiveType::Parameter,
  470. +            Primitive::Sigmoid(_, _) => PrimitiveType::Sigmoid,
  471.          }
  472.      }
  473.  }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement