MrThoe

Neural Network in p5

May 31st, 2022
829
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  1. /** This is a neural network written in P5.
  2.   * @Author:  Daniel Shiffman
  3.   * @Source:  https://www.youtube.com/watch?v=YtRA6tqgJBc
  4.   */
  5.  
  6. class ActivationFunction {
  7.   constructor(func, dfunc) {
  8.     this.func = func;
  9.     this.dfunc = dfunc;
  10.   }
  11. }
  12.  
  13. let sigmoid = new ActivationFunction(
  14.   x => 1 / (1 + Math.exp(-x)),
  15.   y => y * (1 - y)
  16. );
  17.  
  18. let tanh = new ActivationFunction(
  19.   x => Math.tanh(x),
  20.   y => 1 - (y * y)
  21. );
  22.  
  23.  
  24. class NeuralNetwork {
  25.   // TODO: document what a, b, c are
  26.   constructor(a, b, c) {
  27.     if (a instanceof NeuralNetwork) {
  28.       this.input_nodes = a.input_nodes;
  29.       this.hidden_nodes = a.hidden_nodes;
  30.       this.output_nodes = a.output_nodes;
  31.  
  32.       this.weights_ih = a.weights_ih.copy();
  33.       this.weights_ho = a.weights_ho.copy();
  34.  
  35.       this.bias_h = a.bias_h.copy();
  36.       this.bias_o = a.bias_o.copy();
  37.     } else {
  38.       this.input_nodes = a;
  39.       this.hidden_nodes = b;
  40.       this.output_nodes = c;
  41.  
  42.       this.weights_ih = new Matrix(this.hidden_nodes, this.input_nodes);
  43.       this.weights_ho = new Matrix(this.output_nodes, this.hidden_nodes);
  44.       this.weights_ih.randomize();
  45.       this.weights_ho.randomize();
  46.  
  47.       this.bias_h = new Matrix(this.hidden_nodes, 1);
  48.       this.bias_o = new Matrix(this.output_nodes, 1);
  49.       this.bias_h.randomize();
  50.       this.bias_o.randomize();
  51.     }
  52.  
  53.     // TODO: copy these as well
  54.     this.setLearningRate();
  55.     this.setActivationFunction();
  56.  
  57.  
  58.   }
  59.  
  60.   predict(input_array) {
  61.  
  62.     // Generating the Hidden Outputs
  63.     let inputs = Matrix.fromArray(input_array);
  64.     let hidden = Matrix.multiply(this.weights_ih, inputs);
  65.     hidden.add(this.bias_h);
  66.     // activation function!
  67.     hidden.map(this.activation_function.func);
  68.  
  69.     // Generating the output's output!
  70.     let output = Matrix.multiply(this.weights_ho, hidden);
  71.     output.add(this.bias_o);
  72.     output.map(this.activation_function.func);
  73.  
  74.     // Sending back to the caller!
  75.     return output.toArray();
  76.   }
  77.  
  78.   setLearningRate(learning_rate = 0.1) {
  79.     this.learning_rate = learning_rate;
  80.   }
  81.  
  82.   setActivationFunction(func = sigmoid) {
  83.     this.activation_function = func;
  84.   }
  85.  
  86.   train(input_array, target_array) {
  87.     // Generating the Hidden Outputs
  88.     let inputs = Matrix.fromArray(input_array);
  89.     let hidden = Matrix.multiply(this.weights_ih, inputs);
  90.     hidden.add(this.bias_h);
  91.     // activation function!
  92.     hidden.map(this.activation_function.func);
  93.  
  94.     // Generating the output's output!
  95.     let outputs = Matrix.multiply(this.weights_ho, hidden);
  96.     outputs.add(this.bias_o);
  97.     outputs.map(this.activation_function.func);
  98.  
  99.     // Convert array to matrix object
  100.     let targets = Matrix.fromArray(target_array);
  101.  
  102.     // Calculate the error
  103.     // ERROR = TARGETS - OUTPUTS
  104.     let output_errors = Matrix.subtract(targets, outputs);
  105.  
  106.     // let gradient = outputs * (1 - outputs);
  107.     // Calculate gradient
  108.     let gradients = Matrix.map(outputs, this.activation_function.dfunc);
  109.     gradients.multiply(output_errors);
  110.     gradients.multiply(this.learning_rate);
  111.  
  112.  
  113.     // Calculate deltas
  114.     let hidden_T = Matrix.transpose(hidden);
  115.     let weight_ho_deltas = Matrix.multiply(gradients, hidden_T);
  116.  
  117.     // Adjust the weights by deltas
  118.     this.weights_ho.add(weight_ho_deltas);
  119.     // Adjust the bias by its deltas (which is just the gradients)
  120.     this.bias_o.add(gradients);
  121.  
  122.     // Calculate the hidden layer errors
  123.     let who_t = Matrix.transpose(this.weights_ho);
  124.     let hidden_errors = Matrix.multiply(who_t, output_errors);
  125.  
  126.     // Calculate hidden gradient
  127.     let hidden_gradient = Matrix.map(hidden, this.activation_function.dfunc);
  128.     hidden_gradient.multiply(hidden_errors);
  129.     hidden_gradient.multiply(this.learning_rate);
  130.  
  131.     // Calcuate input->hidden deltas
  132.     let inputs_T = Matrix.transpose(inputs);
  133.     let weight_ih_deltas = Matrix.multiply(hidden_gradient, inputs_T);
  134.  
  135.     this.weights_ih.add(weight_ih_deltas);
  136.     // Adjust the bias by its deltas (which is just the gradients)
  137.     this.bias_h.add(hidden_gradient);
  138.  
  139.     // outputs.print();
  140.     // targets.print();
  141.     // error.print();
  142.   }
  143.  
  144.   serialize() {
  145.     return JSON.stringify(this);
  146.   }
  147.  
  148.   static deserialize(data) {
  149.     if (typeof data == 'string') {
  150.       data = JSON.parse(data);
  151.     }
  152.     let nn = new NeuralNetwork(data.input_nodes, data.hidden_nodes, data.output_nodes);
  153.     nn.weights_ih = Matrix.deserialize(data.weights_ih);
  154.     nn.weights_ho = Matrix.deserialize(data.weights_ho);
  155.     nn.bias_h = Matrix.deserialize(data.bias_h);
  156.     nn.bias_o = Matrix.deserialize(data.bias_o);
  157.     nn.learning_rate = data.learning_rate;
  158.     return nn;
  159.   }
  160.  
  161.  
  162.   // Adding function for neuro-evolution
  163.   copy() {
  164.     return new NeuralNetwork(this);
  165.   }
  166.  
  167.   // Accept an arbitrary function for mutation
  168.   mutate(func) {
  169.     this.weights_ih.map(func);
  170.     this.weights_ho.map(func);
  171.     this.bias_h.map(func);
  172.     this.bias_o.map(func);
  173.   }
  174.  
  175. }
  176.  
Advertisement
Add Comment
Please, Sign In to add comment