Advertisement
Guest User

Untitled

a guest
Dec 28th, 2018
90
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Rust 3.76 KB | None | 0 0
  1. use super::Matrix;
  2. use super::Network;
  3. use super::Vector;
  4. use std::cell::RefCell;
  5. use std::rc::Rc;
  6.  
/// A network layer as seen by the trainer: the layer's activation values
/// plus a same-width gradient buffer written during back propagation.
/// Both are shared (`Rc<RefCell<_>>`) with the owning `Network` / `Trainer`.
pub struct Layer {
  // Activation values for this layer (shared with the network).
  vector: Rc<RefCell<Vector>>,
  // Error gradient for this layer, filled in by `Trainer::backward`.
  gradient: Rc<RefCell<Vector>>,
}
/// A weight matrix paired with its per-element momentum buffer.
pub struct Weight {
  // The weight matrix itself (shared with the network).
  matrix: Rc<RefCell<Matrix>>,
  // Previous update per weight, used for the momentum term in SGD.
  deltas: Rc<RefCell<Matrix>>,
}
/// One trainable unit: an input layer connected to an output layer
/// through a weight matrix. Kernel `n` spans network layers `n` and `n+1`.
pub struct Kernel {
  input: Layer,
  output: Layer,
  weight: Weight,
}
  20.  
/// Hyper-parameters for the SGD trainer.
struct TrainerOptions {
  // Learning rate applied to each weight update.
  step: f32,
  // Fraction of the previous delta carried into the next update.
  momentum: f32,
}
  25.  
/// Stochastic Gradient Descent Trainer. This trainer implements a simple
/// SGD back propagation algorithm on the given network. Provides a proxy
/// forward() function to the underlying network. Training done on the
/// backward() function.
pub struct Trainer {
  // The network being trained; the trainer takes ownership of it.
  network: Network,
  // One kernel per weight matrix, linking adjacent layers.
  kernels: Vec<Kernel>,
  // SGD hyper-parameters (step / momentum).
  options: TrainerOptions,
}
  35. impl Trainer {
  36.   pub fn new(network: Network) -> Trainer {
  37.     let options = TrainerOptions {
  38.       step: 0.15,
  39.       momentum: 0.005,
  40.     };
  41.  
  42.     let gradients = network
  43.       .vectors
  44.       .iter()
  45.       .map(|v| {
  46.         let v = v.borrow();
  47.         Rc::new(RefCell::new(Vector::new(v.width)))
  48.       })
  49.       .collect::<Vec<_>>();
  50.  
  51.     let deltas = network
  52.       .matrices
  53.       .iter()
  54.       .map(|m| {
  55.         let m = m.borrow();
  56.         Rc::new(RefCell::new(Matrix::new(m.width, m.height)))
  57.       })
  58.       .collect::<Vec<_>>();
  59.  
  60.     let kernels = (0..network.matrices.len())
  61.       .map(|n| Kernel {
  62.         input: Layer {
  63.           vector: network.vectors[n + 0].clone(),
  64.           gradient: gradients[n + 0].clone(),
  65.         },
  66.         output: Layer {
  67.           vector: network.vectors[n + 1].clone(),
  68.           gradient: gradients[n + 1].clone(),
  69.         },
  70.         weight: Weight {
  71.           matrix: network.matrices[n + 0].clone(),
  72.           deltas: deltas[n + 0].clone(),
  73.         },
  74.       })
  75.       .collect::<Vec<_>>();
  76.  
  77.     Trainer {
  78.       kernels,
  79.       network,
  80.       options,
  81.     }
  82.   }
  83.  
  84.   pub fn forward(&mut self, input: Vec<f32>) -> Vec<f32> {
  85.     self.network.forward(input)
  86.   }
  87.  
  88.   fn derive(&self, x: f32) -> f32 {
  89.     1.0 - x * x
  90.   }
  91.  
  92.   pub fn backward(&mut self, input: Vec<f32>, expect: Vec<f32>) {
  93.     // phase 0: execute the network, write to output layer.
  94.     self.network.forward(input);
  95.  
  96.     // phase 1: calculate output layer gradients.
  97.     let kernel = self.kernels.last().unwrap();
  98.     for o in 0..kernel.weight.matrix.borrow().height {
  99.       let delta = expect[o] - kernel.output.vector.borrow()[o];
  100.       let value = delta * self.derive(kernel.output.vector.borrow()[0]);
  101.       kernel.output.gradient.borrow_mut()[o] = value;
  102.     }
  103.  
  104.     // phase 2: calculate gradients on hidden layers.
  105.     for k in (0..self.kernels.len()).rev() {
  106.       let kernel = &self.kernels[k];
  107.       let mut input_gradient = kernel.input.gradient.borrow_mut();
  108.       let output_gradient = kernel.output.gradient.borrow();
  109.       let matrix = kernel.weight.matrix.borrow();
  110.       input_gradient.set(matrix.backward(&output_gradient).data);
  111.       input_gradient.derive(|x| x * self.derive(*x));
  112.     }
  113.  
  114.     // phase 3: gradient decent on the weights.
  115.     for k in (0..self.kernels.len()).rev() {
  116.       let kernel = &self.kernels[k];
  117.       let mut matrix = kernel.weight.matrix.borrow_mut();
  118.       let mut deltas = kernel.weight.deltas.borrow_mut();
  119.       let input_vector = kernel.input.vector.borrow();
  120.       let output_gradient = kernel.output.gradient.borrow();
  121.  
  122.       for i in 0..matrix.width {
  123.         for o in 0..matrix.height {
  124.           let old_delta = deltas[(i, o)];
  125.           let new_delta = (self.options.step * input_vector[i] * output_gradient[o])
  126.             + (self.options.momentum * old_delta);
  127.  
  128.           let new_weight = matrix[(i, o)] + new_delta;
  129.           matrix[(i, o)] = new_weight;
  130.           deltas[(i, o)] = new_delta;
  131.         }
  132.       }
  133.     }
  134.   }
  135. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement