hedgefund

rust_autograd_02

Jan 16th, 2025
40
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Rust 4.68 KB | Source Code | 0 0
  1. use std::time::{Instant, SystemTime, UNIX_EPOCH};
  2. use rand::prelude::*;
  3.  
  4. #[derive(Debug, Clone, Copy)]
  5. struct NaiveVar {
  6.     val: f64,
  7.     grad: f64,
  8. }
  9.  
  10. #[derive(Clone, Copy)]
  11. enum OpType {
  12.     Sum,
  13.     Prod,
  14.     Softplus,
  15. }
  16.  
  17. struct Operation {
  18.     op_type: OpType,
  19.     input_indices: [usize; 2], // fixed-size array for input indices
  20.     n_inputs: usize,
  21.     output_index: usize,
  22. }
  23.  
  24. struct NaiveTape {
  25.     ops: Vec<Operation>,
  26.     vars: Vec<NaiveVar>,
  27. }
  28.  
  29. impl NaiveTape {
  30.     fn init() -> Self {
  31.         NaiveTape {
  32.             ops: Vec::with_capacity(16),
  33.             vars: Vec::with_capacity(16),
  34.         }
  35.     }
  36.  
  37.     fn create_var(&mut self, val: f64) -> usize {
  38.         let id = self.vars.len();
  39.         self.vars.push(NaiveVar { val, grad: 0.0 });
  40.         id
  41.     }
  42.  
  43.     fn sum(&mut self, input_indices: &[usize]) -> usize {
  44.         let sum_val: f64 = input_indices.iter().map(|&id| self.vars[id].val).sum();
  45.         let output_index = self.create_var(sum_val);
  46.  
  47.         self.ops.push(Operation {
  48.             op_type: OpType::Sum,
  49.             input_indices: {
  50.                 let mut arr = [0; 2];
  51.                 arr[..input_indices.len()].copy_from_slice(input_indices);
  52.                 arr
  53.             },
  54.             n_inputs: input_indices.len(),
  55.             output_index,
  56.         });
  57.         output_index
  58.     }
  59.  
  60.     fn prod(&mut self, input1_idx: usize, input2_idx: usize) -> usize {
  61.         let prod_val = self.vars[input1_idx].val * self.vars[input2_idx].val;
  62.         let output_index = self.create_var(prod_val);
  63.  
  64.         self.ops.push(Operation {
  65.             op_type: OpType::Prod,
  66.             input_indices: [input1_idx, input2_idx],
  67.             n_inputs: 2, // fixed number of inputs for Prod
  68.             output_index,
  69.         });
  70.         output_index
  71.     }
  72.  
  73.     fn softplus(&mut self, input_idx: usize) -> usize {
  74.         let softplus_val = (self.vars[input_idx].val.exp() + 1.0).ln();
  75.         let output_index = self.create_var(softplus_val);
  76.  
  77.         self.ops.push(Operation {
  78.             op_type: OpType::Softplus,
  79.             input_indices: [input_idx, 0], // not using the second index
  80.             n_inputs: 1,
  81.             output_index,
  82.         });
  83.         output_index
  84.     }
  85.  
  86.     fn backward(&mut self, output_idx: usize) {
  87.         self.vars[output_idx].grad = 1.0;
  88.  
  89.         for op in self.ops.iter().rev() {
  90.             let output_grad = self.vars[op.output_index].grad;
  91.  
  92.             match op.op_type {
  93.                 OpType::Sum => {
  94.                     for &input_idx in &op.input_indices[..op.n_inputs] {
  95.                         self.vars[input_idx].grad += output_grad;
  96.                     }
  97.                 }
  98.                 OpType::Prod => {
  99.                     let input1_val = self.vars[op.input_indices[0]].val;
  100.                     let input2_val = self.vars[op.input_indices[1]].val;
  101.                     self.vars[op.input_indices[0]].grad += input2_val * output_grad;
  102.                     self.vars[op.input_indices[1]].grad += input1_val * output_grad;
  103.                 }
  104.                 OpType::Softplus => {
  105.                     let input_val = self.vars[op.input_indices[0]].val;
  106.                     let exp_val = (-input_val).exp();
  107.                     self.vars[op.input_indices[0]].grad += output_grad / (1.0 + exp_val);
  108.                 }
  109.             }
  110.         }
  111.     }
  112. }
  113.  
  114. fn seed_rng() -> impl Rng {
  115.     let now = SystemTime::now();
  116.     let since_epoch = now
  117.         .duration_since(UNIX_EPOCH)
  118.         .expect("Time went backwards");
  119.     let seed = since_epoch.as_nanos() as u64;
  120.  
  121.     rand::rngs::StdRng::seed_from_u64(seed)
  122. }
  123.  
  124. fn main() {
  125.     let iterations: usize = 1000000;
  126.     let mut rng = seed_rng();
  127.  
  128.     let start = Instant::now();
  129.     for i in 0..iterations {
  130.         let mut tape = NaiveTape::init();
  131.  
  132.         let var1_idx = tape.create_var(rng.gen::<f64>());
  133.         let var2_idx = tape.create_var(rng.gen::<f64>());
  134.  
  135.         let sum_idx = tape.sum(&[var1_idx, var2_idx]);
  136.         let prod_idx = tape.prod(sum_idx, sum_idx);
  137.         let softplus_idx = tape.softplus(prod_idx);
  138.  
  139.         tape.backward(softplus_idx);
  140.  
  141.         if i == iterations - 1 {
  142.             println!("sum_var val: {}", tape.vars[sum_idx].val);
  143.             println!("prod_var val: {}", tape.vars[prod_idx].val);
  144.             println!("softplus_var val: {}", tape.vars[softplus_idx].val);
  145.             println!("sum_var grad: {}", tape.vars[sum_idx].grad);
  146.             println!("var1 grad: {}", tape.vars[var1_idx].grad);
  147.             println!("var2 grad: {}", tape.vars[var2_idx].grad);
  148.         }
  149.     }
  150.     let duration = start.elapsed().as_micros() as f64 / 1000.0;
  151.     println!("\nElapsed time: {:.2} ms\n", duration);
  152. }
  153.  
Advertisement
Add Comment
Please, Sign In to add comment