Advertisement
Guest User

Untitled

a guest
Dec 28th, 2018
231
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Rust 1.80 KB | None | 0 0
  1. use super::Matrix;
  2. use super::Vector;
  3. use std::cell::RefCell;
  4. use std::rc::Rc;
  5.  
  6. #[derive(Debug)]
  7. pub struct Network {
  8.   pub vectors: Vec<Rc<RefCell<Vector>>>,
  9.   pub matrices: Vec<Rc<RefCell<Matrix>>>,
  10. }
  11.  
  12. impl Network {
  13.   /// Creates a new network with the given vectors and matrices.
  14.   #[allow(dead_code)]
  15.   pub fn new(vectors: Vec<Rc<RefCell<Vector>>>, matrices: Vec<Rc<RefCell<Matrix>>>) -> Network {
  16.     Network { vectors, matrices }
  17.   }
  18.  
  19.   /// Executes this network in a feed forward fashion.
  20.   #[allow(dead_code)]
  21.   pub fn forward(&mut self, input: Vec<f32>) -> Vec<f32> {
  22.     // load data into input vector.
  23.     self.vectors[0].borrow_mut().set(input);
  24.  
  25.     // feed forward data.
  26.     for i in 0..self.matrices.len() {
  27.  
  28.       let mut output = self.vectors [i + 1].borrow_mut();
  29.       let input      = self.vectors [i + 0].borrow();
  30.       let matrix     = self.matrices[i + 0].borrow();
  31.      
  32.       // multiply forward and write to output.
  33.       output.set(matrix.forward(&input).data);
  34.       output.activate(|x| {
  35.         let exp_0 = (-x).exp();
  36.         let exp_1 = (x).exp();
  37.         (exp_0 - exp_1) / (exp_0 + exp_1)
  38.       });
  39.     }
  40.  
  41.     // unload data
  42.     self.vectors.last().unwrap().borrow().data.clone()
  43.   }
  44.  
  45.   /// Creates a new network with the given size constraints.
  46.   #[allow(dead_code)]
  47.   pub fn create(sizes: Vec<usize>) -> Network {
  48.     let vectors = sizes
  49.       .iter()
  50.       .map(|size| Rc::new(RefCell::new(Vector::new(*size))))
  51.       .collect::<Vec<_>>();
  52.  
  53.     let matrices = (0..vectors.len() - 1)
  54.       .map(|i| {
  55.         let input = vectors[i].borrow();
  56.         let output = vectors[i + 1].borrow();
  57.         Rc::new(RefCell::new(Matrix::random(input.width, output.width)))
  58.       })
  59.       .collect::<Vec<_>>();
  60.  
  61.     Network::new(vectors, matrices)
  62.   }
  63. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement