Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- diff --git a/rust/adagio/Cargo.toml b/rust/adagio/Cargo.toml
- --- a/rust/adagio/Cargo.toml
- +++ b/rust/adagio/Cargo.toml
- @@ -6,3 +6,4 @@
- [dependencies]
- num = "0.1.37"
- num-rational = "0.1.36"
- +void = "1.0.2"
- diff --git a/rust/adagio/src/builder.rs b/rust/adagio/src/builder.rs
- --- a/rust/adagio/src/builder.rs
- +++ b/rust/adagio/src/builder.rs
- @@ -1,21 +1,21 @@
- -use num_rational::Rational;
- +use num_rational::Ratio;
- -use {Environment, Node};
- +use {Environment, LiteralType, Node};
- use node_defs::{Container, NodeData, Sum};
- +use node_ops::NodeExpr;
- -impl<T:NodeData> Environment<T> {
- - fn swlu<I: Node<T>>(&mut self, input: Container<I>) -> Container<Sum<T>> where T::Constant: From<Rational> {
- - let leakiness = self.number(Rational::new(1, 10));
- - let one_half = self.number(Rational::new(1, 2));
- - let one = self.number(Rational::new(1, 1));
- - let minus_one = self.number(Rational::new(-1, 1));
- - let tanh = self.tanh(input.clone());
- - let minus_tanh = self.multiply(minus_one, tanh.clone());
- - let one_plus = self.add(one.clone(), tanh);
- - let one_minus = self.add(one, minus_tanh);
- - let scaled = self.multiply(leakiness, one_minus);
- - let sum = self.add(one_plus, scaled);
- - let half_input = self.multiply(one_half, input);
- - self.multiply(half_input, sum)
- +impl<T: NodeData> NodeExpr<T> {
- + pub fn swlu(self) -> NodeExpr<T>
- + where T::Constant: From<Ratio<LiteralType>>
- + {
- + self.clone() * constant((1, 2)) *
- + ((self.clone().tanh() + constant(1)) +
- + (constant(1) - self.clone().tanh()) * constant((1, 10)))
- }
- -}
- \ No newline at end of file
- +}
- +
- +fn constant<T: NodeData, C: Into<Ratio<LiteralType>>>(constant: C) -> NodeExpr<T>
- + where T::Constant: From<Ratio<LiteralType>>
- +{
- + NodeExpr::constant(constant.into())
- +}
- diff --git a/rust/adagio/src/eval.rs b/rust/adagio/src/eval.rs
- --- a/rust/adagio/src/eval.rs
- +++ b/rust/adagio/src/eval.rs
- @@ -66,11 +66,11 @@
- fn eval_primitive(&mut self, state: &State, primitive: Container<Primitive<T>>) -> f64 {
- if !self.primitives.contains_key(&primitive.clone()) {
- - let result = match primitive.as_ref() {
- - &Primitive::Input(index) => state.inputs[index],
- - &Primitive::Real(_) => unimplemented!(),
- - &Primitive::Parameter(index) => state.parameters[index],
- - &Primitive::Sigmoid(minus, ref sum) => {
- + let result = match *primitive.as_ref() {
- + Primitive::Input(index) => state.inputs[index],
- + Primitive::Real(_) => unimplemented!(),
- + Primitive::Parameter(index) => state.parameters[index],
- + Primitive::Sigmoid(minus, ref sum) => {
- let tanh = self.eval_sum(state, sum.clone()).tanh();
- if minus { -tanh } else { tanh }
- }
- diff --git a/rust/adagio/src/lib.rs b/rust/adagio/src/lib.rs
- --- a/rust/adagio/src/lib.rs
- +++ b/rust/adagio/src/lib.rs
- @@ -1,5 +1,6 @@
- extern crate num;
- extern crate num_rational;
- +extern crate void;
- use std::cmp::Ordering;
- use std::collections::{HashMap, HashSet};
- @@ -9,10 +10,15 @@
- mod eval;
- mod node_conversions;
- mod node_defs;
- +mod node_ops;
- mod node_ordering;
- use node_conversions::Summable;
- use node_defs::{Container, container, FancyClone, NodeData, Sum, Product, Power, Primitive};
- +use node_ops::{AsExpr, NodeExpr};
- +
- +///Writing code that's generic over scalars is suffering.
- +type LiteralType = i32;
- /// Negates a Sum node.
- impl<T: NodeData> Sum<T> {
- @@ -77,6 +83,19 @@
- }
- impl<T: NodeData> Environment<T> {
- + pub fn new() -> Environment<T> {
- + Environment {
- + sum: HashSet::new(),
- + product: HashSet::new(),
- + power: HashSet::new(),
- + primitive: HashSet::new(),
- + sum_derivative: HashMap::new(),
- + product_derivative: HashMap::new(),
- + power_derivative: HashMap::new(),
- + primitive_derivative: HashMap::new(),
- + }
- + }
- +
- /// Makes a canonicalized Sum from the given components.
- pub fn make_sum<C, N: Into<T::Constant>>(&mut self,
- minus: bool,
- @@ -298,9 +317,7 @@
- }
- }
- - last = self.add(last, inner);
- - last = self.add(last, outer);
- - self.add(last, first)
- + self.eval(first.expr() + outer.expr() + inner.expr() + last.expr())
- }
- fn _combine_terms(&mut self,
- @@ -314,9 +331,18 @@
- let mut left_index = 0;
- let mut right_index = 0;
- - while left_index < left_length && right_index < right_length {
- + println!("Vectors equal: {}", left == right);
- + //println!("Left: {:?} right: {:?}", left, right);
- +
- + while (left_index < left_length) && (right_index < right_length) {
- + assert!(left_index < left.len(), "Left OOB");
- + assert!(right_index < right.len(), "Right OOB");
- + println!("Right len: {} index: {}", right.len(), right_index);
- + let ref right_power = right[right_length];
- + println!("Right indexing succeeded");
- + println!("Left len: {} index: {} right len: {} index: {}", left.len(), left_index, right.len(), right_index);
- let ref left_power = left[left_index];
- - let ref right_power = right[right_length];
- + println!("Left indexing succeeded");
- match left_power.fuzzy_cmp(right_power.as_ref()) {
- Ordering::Less => {
- powers.push(left_power.clone());
- @@ -351,9 +377,38 @@
- let sum = sum.as_sum(self);
- self.make_sigmoid(false, sum)
- }
- +
- + pub fn eval(&mut self, expr: NodeExpr<T>) -> Container<Sum<T>> {
- + match expr {
- + NodeExpr::Sum(a) => a.insert(self),
- + NodeExpr::Product(a) => a.as_sum(self),
- + NodeExpr::Power(a) => a.as_sum(self),
- + NodeExpr::Primitive(a) => a.as_sum(self),
- + NodeExpr::Constant(a) => self.number(a),
- + NodeExpr::Add(a, b) => {
- + let left = self.eval(*a);
- + let right = self.eval(*b);
- + self.add(left, right)
- + }
- + NodeExpr::Subtract(a, b) => {
- + let left = self.eval(*a);
- + let right = self.eval(*b).negate(self);
- + self.add(left, right)
- + }
- + NodeExpr::Multiply(a, b) => {
- + let left = self.eval(*a);
- + let right = self.eval(*b);
- + self.multiply(left, right)
- + }
- + NodeExpr::Tanh(a) => {
- + let sum = self.eval(*a);
- + self.tanh(sum).as_sum(self)
- + }
- + }
- + }
- }
- -pub trait Node<T: NodeData>: Summable<T> + FancyClone<T> {
- +pub trait Node<T: NodeData>: Summable<T> + FancyClone<T> + AsExpr<T> {
- fn partial_derivative(&self,
- variable: &T::Parameter,
- env: &mut Environment<T>)
- @@ -497,8 +552,8 @@
- variable: &T::Parameter,
- env: &mut Environment<T>)
- -> Container<Sum<T>> {
- - match self {
- - &Primitive::Sigmoid(minus, ref sum) => {
- + match *self {
- + Primitive::Sigmoid(minus, ref sum) => {
- let derivative = sum.partial_derivative(variable, env);
- let self_in_env = self.insert(env);
- let self_squared = env.secret_make_power(self_in_env, 2);
- @@ -506,7 +561,7 @@
- let sum = env.make_sum(minus, 1, vec![minus_self_squared]);
- env.multiply(derivative, sum)
- }
- - &Primitive::Parameter(ref p) => env.number(if p == variable { 1 } else { 0 }),
- + Primitive::Parameter(ref p) => env.number(if p == variable { 1 } else { 0 }),
- _ => env.number(0),
- }
- }
- @@ -527,6 +582,50 @@
- #[cfg(test)]
- mod tests {
- + use num_rational;
- + use void;
- +
- + use {Environment, LiteralType, NodeData};
- + use node_ops::AsExpr;
- #[test]
- - fn it_works() {}
- + fn construct_xor() {
- + struct MyNodeData(void::Void);
- +
- + impl NodeData for MyNodeData {
- + type Constant = num_rational::Ratio<LiteralType>;
- + type Coefficient = num_rational::Ratio<LiteralType>;
- + type Exponent = LiteralType;
- + type Input = usize;
- + type Real = bool; // Pls no use.
- + type Parameter = usize;
- + }
- +
- + let mut env = Environment::<MyNodeData>::new();
- + let mut param_count = 0..;
- +
- + let left_input = env.make_input(0);
- + let right_input = env.make_input(1);
- +
- + let left_hidden =
- + (left_input.expr() * env.make_parameter(param_count.next().unwrap()).expr() +
- + right_input.expr() * env.make_parameter(param_count.next().unwrap()).expr() +
- + env.make_parameter(param_count.next().unwrap()).expr())
- + .swlu();
- + let right_hidden =
- + (left_input.expr() * env.make_parameter(param_count.next().unwrap()).expr() +
- + right_input.expr() * env.make_parameter(param_count.next().unwrap()).expr() +
- + env.make_parameter(param_count.next().unwrap()).expr())
- + .swlu();
- +
- + let result = left_hidden * env.make_parameter(param_count.next().unwrap()).expr() +
- + right_hidden * env.make_parameter(param_count.next().unwrap()).expr() +
- + env.make_parameter(param_count.next().unwrap()).expr();
- +
- + let expected = env.make_input(2);
- + let difference = result.clone() - expected.expr();
- + let error = difference.clone() * difference;
- +
- + let result_container = env.eval(result);
- + let error_container = env.eval(error);
- + }
- }
- diff --git a/rust/adagio/src/node_conversions.rs b/rust/adagio/src/node_conversions.rs
- --- a/rust/adagio/src/node_conversions.rs
- +++ b/rust/adagio/src/node_conversions.rs
- @@ -44,8 +44,8 @@
- impl<T: NodeData> Productable<T> for Primitive<T> {
- fn as_product(&self, env: &mut Environment<T>) -> Container<Product<T>> {
- - match self {
- - &Primitive::Sigmoid(true, ref sum) => {
- + match *self {
- + Primitive::Sigmoid(true, ref sum) => {
- let sigmoid = env.make_sigmoid(false, sum.clone());
- let power = sigmoid.as_power(env);
- env.secret_make_product(-1, vec![power])
- diff --git a/rust/adagio/src/node_defs.rs b/rust/adagio/src/node_defs.rs
- --- a/rust/adagio/src/node_defs.rs
- +++ b/rust/adagio/src/node_defs.rs
- @@ -5,7 +5,7 @@
- use num::{Integer, Signed};
- -use {Environment, Node};
- +use {Environment, LiteralType, Node};
- pub type Container<T> = Arc<T>;
- @@ -14,9 +14,9 @@
- }
- pub trait NodeData {
- - type Constant: Hash + Ord + Clone + Signed + Into<Self::Coefficient> + From<i32>;
- - type Coefficient: Hash + Ord + Clone + Signed + Into<Self::Constant> + From<i32>;
- - type Exponent: Hash + Clone + Integer + Into<Self::Constant> + From<i32>;
- + type Constant: Hash + Ord + Clone + Signed + Into<Self::Coefficient> + From<LiteralType>;
- + type Coefficient: Hash + Ord + Clone + Signed + Into<Self::Constant> + From<LiteralType>;
- + type Exponent: Hash + Clone + Integer + Into<Self::Constant> + From<LiteralType>;
- type Input: Hash + Ord + Clone;
- type Real: Hash + Ord + Clone;
- type Parameter: Hash + Ord + Clone;
- @@ -130,20 +130,20 @@
- impl<T: NodeData> Hash for Primitive<T> {
- fn hash<H: Hasher>(&self, state: &mut H) {
- - match self {
- - &Primitive::Input(ref a) => {
- + match *self {
- + Primitive::Input(ref a) => {
- "Input".hash(state);
- a.hash(state)
- }
- - &Primitive::Real(ref a) => {
- + Primitive::Real(ref a) => {
- "Real".hash(state);
- a.hash(state)
- }
- - &Primitive::Parameter(ref a) => {
- + Primitive::Parameter(ref a) => {
- "Parameter".hash(state);
- a.hash(state)
- }
- - &Primitive::Sigmoid(minus, ref a) => {
- + Primitive::Sigmoid(minus, ref a) => {
- "Sigmoid".hash(state);
- minus.hash(state);
- a.hash(state)
- @@ -184,8 +184,8 @@
- impl<T: NodeData> FancyClone<T> for Primitive<T> {
- fn fancy_clone(&self, env: &mut Environment<T>) -> Primitive<T> {
- - match self {
- - &Primitive::Sigmoid(minus, ref a) => Primitive::Sigmoid(minus, a.insert(env)),
- + match *self {
- + Primitive::Sigmoid(minus, ref a) => Primitive::Sigmoid(minus, a.insert(env)),
- _ => self.clone(),
- }
- }
- @@ -223,11 +223,11 @@
- impl<T: NodeData> Clone for Primitive<T> {
- fn clone(&self) -> Primitive<T> {
- - match self {
- - &Primitive::Input(ref a) => Primitive::Input(a.clone()),
- - &Primitive::Real(ref a) => Primitive::Real(a.clone()),
- - &Primitive::Parameter(ref a) => Primitive::Parameter(a.clone()),
- - &Primitive::Sigmoid(minus, ref a) => Primitive::Sigmoid(minus, a.clone()),
- + match *self {
- + Primitive::Input(ref a) => Primitive::Input(a.clone()),
- + Primitive::Real(ref a) => Primitive::Real(a.clone()),
- + Primitive::Parameter(ref a) => Primitive::Parameter(a.clone()),
- + Primitive::Sigmoid(minus, ref a) => Primitive::Sigmoid(minus, a.clone()),
- }
- }
- }
- diff --git a/rust/adagio/src/node_ops.rs b/rust/adagio/src/node_ops.rs
- new file mode 100644
- --- /dev/null
- +++ b/rust/adagio/src/node_ops.rs
- @@ -0,0 +1,94 @@
- +use std::ops::{Add, Mul, Sub};
- +
- +use Node;
- +use node_defs::{NodeData, Sum, Product, Power, Primitive};
- +
- +pub enum NodeExpr<T: NodeData> {
- + Sum(Sum<T>),
- + Product(Product<T>),
- + Power(Power<T>),
- + Primitive(Primitive<T>),
- + Constant(T::Constant),
- + Add(Box<NodeExpr<T>>, Box<NodeExpr<T>>),
- + Subtract(Box<NodeExpr<T>>, Box<NodeExpr<T>>),
- + Multiply(Box<NodeExpr<T>>, Box<NodeExpr<T>>),
- + Tanh(Box<NodeExpr<T>>),
- +}
- +
- +impl<T: NodeData> Clone for NodeExpr<T> {
- + fn clone(&self) -> NodeExpr<T> {
- + match *self {
- + NodeExpr::Sum(ref a) => NodeExpr::Sum(a.clone()),
- + NodeExpr::Product(ref a) => NodeExpr::Product(a.clone()),
- + NodeExpr::Power(ref a) => NodeExpr::Power(a.clone()),
- + NodeExpr::Primitive(ref a) => NodeExpr::Primitive(a.clone()),
- + NodeExpr::Constant(ref a) => NodeExpr::Constant(a.clone()),
- + NodeExpr::Add(ref a, ref b) => NodeExpr::Add(a.clone(), b.clone()),
- + NodeExpr::Subtract(ref a, ref b) => NodeExpr::Subtract(a.clone(), b.clone()),
- + NodeExpr::Multiply(ref a, ref b) => NodeExpr::Multiply(a.clone(), b.clone()),
- + NodeExpr::Tanh(ref a) => NodeExpr::Tanh(a.clone()),
- + }
- + }
- +}
- +
- +impl<T: NodeData> Add for NodeExpr<T> {
- + type Output = NodeExpr<T>;
- +
- + fn add(self, rhs: NodeExpr<T>) -> NodeExpr<T> {
- + NodeExpr::Add(Box::new(self), Box::new(rhs))
- + }
- +}
- +
- +impl<T: NodeData> Sub for NodeExpr<T> {
- + type Output = NodeExpr<T>;
- +
- + fn sub(self, rhs: NodeExpr<T>) -> NodeExpr<T> {
- + NodeExpr::Subtract(Box::new(self), Box::new(rhs))
- + }
- +}
- +
- +impl<T: NodeData> Mul for NodeExpr<T> {
- + type Output = NodeExpr<T>;
- +
- + fn mul(self, rhs: NodeExpr<T>) -> NodeExpr<T> {
- + NodeExpr::Multiply(Box::new(self), Box::new(rhs))
- + }
- +}
- +
- +pub trait AsExpr<T: NodeData> {
- + fn expr(&self) -> NodeExpr<T>;
- +}
- +
- +impl<T: NodeData> AsExpr<T> for Sum<T> {
- + fn expr(&self) -> NodeExpr<T> {
- + NodeExpr::Sum(self.clone())
- + }
- +}
- +
- +impl<T: NodeData> AsExpr<T> for Product<T> {
- + fn expr(&self) -> NodeExpr<T> {
- + NodeExpr::Product(self.clone())
- + }
- +}
- +
- +impl<T: NodeData> AsExpr<T> for Power<T> {
- + fn expr(&self) -> NodeExpr<T> {
- + NodeExpr::Power(self.clone())
- + }
- +}
- +
- +impl<T: NodeData> AsExpr<T> for Primitive<T> {
- + fn expr(&self) -> NodeExpr<T> {
- + NodeExpr::Primitive(self.clone())
- + }
- +}
- +
- +impl<T: NodeData> NodeExpr<T> {
- + pub fn tanh(self) -> NodeExpr<T> {
- + NodeExpr::Tanh(Box::new(self))
- + }
- +
- + pub fn constant<C: Into<T::Constant>>(constant: C) -> NodeExpr<T> {
- + NodeExpr::Constant(constant.into())
- + }
- +}
- diff --git a/rust/adagio/src/node_ordering.rs b/rust/adagio/src/node_ordering.rs
- --- a/rust/adagio/src/node_ordering.rs
- +++ b/rust/adagio/src/node_ordering.rs
- @@ -12,11 +12,11 @@
- impl<T: NodeData> Primitive<T> {
- fn simplify(&self) -> PrimitiveType {
- - match self {
- - &Primitive::Input(_) => PrimitiveType::Input,
- - &Primitive::Real(_) => PrimitiveType::Real,
- - &Primitive::Parameter(_) => PrimitiveType::Parameter,
- - &Primitive::Sigmoid(_, _) => PrimitiveType::Sigmoid,
- + match *self {
- + Primitive::Input(_) => PrimitiveType::Input,
- + Primitive::Real(_) => PrimitiveType::Real,
- + Primitive::Parameter(_) => PrimitiveType::Parameter,
- + Primitive::Sigmoid(_, _) => PrimitiveType::Sigmoid,
- }
- }
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement