Advertisement
Guest User

Untitled

a guest
Dec 4th, 2016
79
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  /** describes one layer of a network, as passed to the `Network` constructor. */
  export interface Descriptor {
    /** how many (non-bias) neurons this layer has. */
    neurons: number;
    /** the activation applied to this layer's neurons; omitted means "linear". */
    activation?: "linear" | "sigmoid" | "hypertan" | "relu" | "softmax";
  }
  7.  
  /** a scalar activation function: maps a neuron's weighted input sum to its output value. */
  export interface Activation {
    (input: number): number;
  }
  11.  
  /** one layer of a network; `neurons[0]` is the layer's bias neuron. */
  export interface Layer {
    /** position of this layer within `Network.layers`. */
    index: number;
    /** index into `Network.activations` selecting this layer's activation function. */
    activation: number;
    /** indices into `Network.neurons` / `Network.values` of this layer's neurons. */
    neurons: number[];
  }
  20.  
  /** a weighted connection from one neuron to a neuron of the next layer. */
  export interface Synapse {
    /** this synapse's index into `Network.synapses`. */
    index: number;
    /** index (into `Network.neurons`) of the neuron this synapse reads from. */
    input: number;
    /** index (into `Network.neurons`) of the neuron this synapse feeds. */
    output: number;
    /** index into `Network.weights` holding this synapse's weight. */
    weight: number;
  }
  31.  
  /** a single neuron; its current value is stored at `Network.values[index]`. */
  export interface Neuron {
    /** this neuron's index into `Network.neurons` and `Network.values`. */
    index: number;
    /** index (into `Network.layers`) of the layer that owns this neuron. */
    layer: number;
    /** indices (into `Network.synapses`) of the synapses feeding this neuron. */
    inputs: number[];
    /** indices (into `Network.synapses`) of the synapses this neuron feeds. */
    outputs: number[];
  }
  42.  
  /**
   * a fully connected feed-forward network.
   *
   * every layer owns one extra neuron at slot 0 of `Layer.neurons` which acts
   * as a bias unit: its value is fixed at 1.0 and it feeds every non-bias
   * neuron of the next layer.
   */
  export class Network {
    /** index of the softmax entry in `activations`; such layers are normalized as a whole in `forward`. */
    private static readonly SOFTMAX = 4

    /** the network layers, input layer first. */
    public layers      : Array<Layer>
    /** every synapse; `synapses[i].index === i`. */
    public synapses    : Array<Synapse>
    /** every neuron across all layers; `neurons[i].index === i`. */
    public neurons     : Array<Neuron>
    /** activation functions, indexed by `Layer.activation`. */
    public activations : Array<Activation>
    /** current neuron values, indexed by neuron index. */
    public values      : Float32Array
    /** synapse weights, indexed by `Synapse.weight`. */
    public weights     : Float32Array

    /**
     * creates a new network from the given layer descriptors.
     *
     * @param descriptors one descriptor per layer, input layer first.
     * @param random source of uniform numbers in [0, 1) used to initialize the
     *   non-bias weights; defaults to `Math.random`. pass a seeded generator to
     *   get reproducible networks.
     * @throws Error if no descriptors are given, a layer's neuron count is not a
     *   positive integer, or an activation type is unknown.
     */
    constructor(descriptors: Array<Descriptor>, random: () => number = Math.random) {
      if (descriptors.length === 0) {
        throw Error("a network requires at least one layer.")
      }

      // maps a descriptor activation name to its index in `this.activations`.
      const activation_map = (activation: Descriptor["activation"]): number => {
        switch (activation) {
          case undefined:  return 0
          case "linear":   return 0
          case "sigmoid":  return 1
          case "hypertan": return 2
          case "relu":     return 3
          case "softmax":  return Network.SOFTMAX
          default: throw Error("unknown activation type " + activation)
        }
      }

      // create activation functions.
      this.activations = [
        /* 0: linear   */ (x: number) => x,
        /* 1: sigmoid  */ (x: number) => 1.0 / (1.0 + Math.exp(-x)),
        /* 2: hypertan */ (x: number) => Math.tanh(x),
        /* 3: relu     */ (x: number) => Math.max(0, x),
        // softmax depends on every neuron of its layer, so it cannot be a scalar
        // function: forward() normalizes softmax layers itself, and this entry
        // is only an identity placeholder keeping the indices aligned.
        /* 4: softmax  */ (x: number) => x,
      ]

      // create network layers. slot 0 of every layer is its bias neuron, so a
      // layer of n neurons occupies n + 1 consecutive neuron indices.
      let neuron_index = 0
      this.layers = descriptors.map((descriptor, index) => {
        if (!Number.isInteger(descriptor.neurons) || descriptor.neurons < 1) {
          throw Error("layer " + index + " must have a positive integer neuron count.")
        }
        // annotated so `neurons` is number[] rather than never[] under strict mode.
        const layer: Layer = {
          index     : index,
          activation: activation_map(descriptor.activation),
          neurons   : []
        }
        for (let i = 0; i <= descriptor.neurons; i += 1) {
          layer.neurons.push(neuron_index)
          neuron_index += 1
        }
        return layer
      })

      // create neurons, one per index handed out above.
      this.neurons = []
      for (const layer of this.layers) {
        for (const index of layer.neurons) {
          this.neurons.push({ index, layer: layer.index, inputs: [], outputs: [] })
        }
      }

      // create synapses: every neuron of layer i (bias included) feeds every
      // non-bias neuron of layer i + 1. each synapse owns the weight slot
      // matching its own index.
      this.synapses = []
      for (let i = 0; i < this.layers.length - 1; i += 1) {
        const input  = this.layers[i]
        const output = this.layers[i + 1]
        for (let output_idx = 1; output_idx < output.neurons.length; output_idx += 1) {
          for (let input_idx = 0; input_idx < input.neurons.length; input_idx += 1) {
            const index = this.synapses.length
            this.synapses.push({
              index  : index,
              input  : input.neurons[input_idx],
              output : output.neurons[output_idx],
              weight : index
            })
          }
        }
      }

      // neuron -> synapse -> neuron.
      for (const synapse of this.synapses) {
        this.neurons[synapse.input].outputs.push(synapse.index)
        this.neurons[synapse.output].inputs.push(synapse.index)
      }

      // setup buffers.
      this.values  = new Float32Array(this.neurons.length)
      this.weights = new Float32Array(this.synapses.length)

      // bias neurons always emit 1.0; forward() never overwrites them.
      for (const layer of this.layers) {
        this.values[layer.neurons[0]] = 1.0
      }

      // bias weights start at 1.0. every other weight is drawn uniformly from
      // [-1/sqrt(fan_in), 1/sqrt(fan_in)) of the receiving neuron: identical
      // weights would make every neuron of a layer compute the same value, and
      // unscaled weights saturate sigmoid/tanh on wide layers.
      for (const synapse of this.synapses) {
        const source = this.neurons[synapse.input]
        const is_bias = this.layers[source.layer].neurons[0] === source.index
        if (is_bias) {
          this.weights[synapse.weight] = 1.0
        } else {
          const fan_in = this.neurons[synapse.output].inputs.length
          this.weights[synapse.weight] = (random() * 2 - 1) / Math.sqrt(fan_in)
        }
      }
    }

    /**
     * feeds the given inputs forward through the network and returns the
     * values of the output layer's non-bias neurons. every neuron's resulting
     * value is also left in `values`.
     *
     * @param input one value per non-bias neuron of the input layer.
     * @returns the output layer values, in neuron order.
     * @throws Error if `input.length` does not match the input layer size.
     */
    public forward(input: Array<number>): Array<number> {
      const input_layer = this.layers[0]
      if (input_layer.neurons.length - 1 !== input.length) {
        throw Error("input length mismatch.")
      }
      // load input, skipping the bias neuron at slot 0.
      for (let i = 0; i < input.length; i += 1) {
        this.values[input_layer.neurons[i + 1]] = input[i]
      }
      // feed forward, one layer at a time.
      for (let i = 1; i < this.layers.length; i += 1) {
        const layer = this.layers[i]
        const targets = layer.neurons.slice(1)
        // weighted input sum for every non-bias neuron of this layer.
        const sums = targets.map(index => {
          let accumulator = 0
          for (const synapse_index of this.neurons[index].inputs) {
            const synapse = this.synapses[synapse_index]
            accumulator += this.values[synapse.input] * this.weights[synapse.weight]
          }
          return accumulator
        })
        let outputs: Array<number>
        if (layer.activation === Network.SOFTMAX) {
          outputs = Network.softmax(sums)
        } else {
          const activation = this.activations[layer.activation]
          outputs = sums.map(sum => activation(sum))
        }
        targets.forEach((index, j) => { this.values[index] = outputs[j] })
      }
      // read output, skipping the bias neuron at slot 0.
      const output_layer = this.layers[this.layers.length - 1]
      return output_layer.neurons.slice(1).map(index => this.values[index])
    }

    /**
     * numerically stable softmax: shifts every value by the maximum before
     * exponentiating so large sums cannot overflow to Infinity.
     */
    private static softmax(values: Array<number>): Array<number> {
      const max = values.reduce((m, v) => Math.max(m, v), -Infinity)
      const exps = values.map(v => Math.exp(v - max))
      const total = exps.reduce((a, b) => a + b, 0)
      return exps.map(e => e / total)
    }
  }
  202.  
  // demo: a 2-input, 1-output network with three 2000-wide sigmoid hidden
  // layers, fed the same input four times and each result printed.
  const network = new Network([
    { neurons: 2 },
    { neurons: 2000, activation: "sigmoid" },
    { neurons: 2000, activation: "sigmoid" },
    { neurons: 2000, activation: "sigmoid" },
    { neurons: 1, activation: "sigmoid" },
  ])

  for (let run = 0; run < 4; run += 1) {
    console.log(network.forward([0, 1]))
  }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement