SHARE
TWEET

Neural Network Example

a guest Mar 25th, 2017 60 Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
  /**
   * Neural Network
   * Feedforward Backpropagation Neural Network
   * Written in 2002 by Jeff Heaton (http://www.jeffheaton.com)
   *
   * This class is released under the GNU Lesser General Public License (LGPL).
   *
   * <p>A fully connected three-layer (input, hidden, output) network trained by
   * backpropagation with momentum. A typical training step is: call
   * {@link #computeOutputs}, then {@link #calcError} with the ideal outputs
   * (repeat for every pattern in a batch), then {@link #learn} to apply the
   * accumulated deltas.
   *
   * <p>Per-neuron arrays are indexed as follows: {@code [0, inputCount)} are the
   * input neurons, {@code [inputCount, inputCount + hiddenCount)} the hidden
   * neurons, and the remainder the output neurons. {@link #weightArray} stores
   * the input-to-hidden weights first (for each hidden neuron, one weight per
   * input), followed by the hidden-to-output weights (for each output neuron,
   * one weight per hidden neuron).
   *
   * <p>This class performs no synchronization; every method reads and writes
   * the shared mutable arrays below.
   *
   * @author Jeff Heaton
   * @version 1.0
   */
  public class Network {

    /** Sum of squared output errors accumulated by {@link #calcError}; cleared by {@link #getError}. */
    protected double globalError;

    /** The number of input neurons. */
    protected int inputCount;

    /** The number of hidden neurons. */
    protected int hiddenCount;

    /** The number of output neurons. */
    protected int outputCount;

    /** The total number of neurons in the network (input + hidden + output). */
    protected int neuronCount;

    /** The number of weights: input-to-hidden plus hidden-to-output. */
    protected int weightCount;

    /** The learning rate. */
    protected double learnRate;

    /** The output of every neuron from the last call to {@link #computeOutputs}. */
    protected double[] neuronOutputs;

    /**
     * The weight matrix. Together with the thresholds it can be thought of as
     * the "memory" of the neural network.
     */
    protected double[] weightArray;

    /** The per-neuron errors from the last call to {@link #calcError}. */
    protected double[] error;

    /** Weight deltas accumulated by {@link #calcError} and consumed by {@link #learn}. */
    protected double[] accMatrixDelta;

    /**
     * The per-neuron thresholds (biases). Together with the weight matrix they can
     * be thought of as the memory of the neural network.
     */
    protected double[] thresholds;

    /** The weight changes applied by the last {@link #learn}; reused for momentum. */
    protected double[] matrixDelta;

    /** Threshold deltas accumulated by {@link #calcError} and consumed by {@link #learn}. */
    protected double[] accThresholdDelta;

    /** The threshold changes applied by the last {@link #learn}; reused for momentum. */
    protected double[] thresholdDelta;

    /** The momentum for training. */
    protected double momentum;

    /** The per-neuron error gradients (error times activation derivative). */
    protected double[] errorDelta;

    /**
     * Construct the neural network and randomize its weights via {@link #reset()}.
     *
     * <p>Note: {@code reset()} is called from this constructor, so a subclass
     * override of it runs before the subclass's own fields are initialized.
     *
     * @param inputCount The number of input neurons.
     * @param hiddenCount The number of hidden neurons.
     * @param outputCount The number of output neurons.
     * @param learnRate The learning rate to be used when training.
     * @param momentum The momentum to be used when training.
     * @throws IllegalArgumentException if any neuron count is negative.
     */
    public Network(int inputCount, int hiddenCount, int outputCount, double learnRate, double momentum) {
      if (inputCount < 0 || hiddenCount < 0 || outputCount < 0) {
        throw new IllegalArgumentException(
            "Neuron counts must be non-negative: input=" + inputCount
                + ", hidden=" + hiddenCount + ", output=" + outputCount);
      }

      this.learnRate = learnRate;
      this.momentum = momentum;

      this.inputCount = inputCount;
      this.hiddenCount = hiddenCount;
      this.outputCount = outputCount;
      neuronCount = inputCount + hiddenCount + outputCount;
      weightCount = (inputCount * hiddenCount) + (hiddenCount * outputCount);

      neuronOutputs = new double[neuronCount];
      weightArray = new double[weightCount];
      matrixDelta = new double[weightCount];
      thresholds = new double[neuronCount];
      errorDelta = new double[neuronCount];
      error = new double[neuronCount];
      accThresholdDelta = new double[neuronCount];
      accMatrixDelta = new double[weightCount];
      thresholdDelta = new double[neuronCount];

      reset();
    }

    /**
     * Returns the root mean square error accumulated since the last call, and
     * clears the accumulator.
     *
     * @param len The number of patterns in the complete training set.
     * @return The RMS error over {@code len * outputCount} output values.
     * @throws IllegalArgumentException if {@code len} is not positive.
     */
    public double getError(int len) {
      if (len <= 0) {
        throw new IllegalArgumentException("len must be positive: " + len);
      }
      // Multiply in double so large training sets cannot overflow int.
      double err = Math.sqrt(globalError / ((double) len * outputCount));
      globalError = 0; // clear the accumulator
      return err;
    }

    /**
     * The threshold (activation) method: the logistic sigmoid. Override to
     * provide another activation, and override {@link #thresholdDerivative}
     * to match.
     *
     * @param sum The weighted input sum of the neuron.
     * @return The activation applied to the threshold method.
     */
    public double threshold(double sum) {
      return 1 / (1 + Math.exp(-sum));
    }

    /**
     * Derivative of {@link #threshold}, expressed in terms of the neuron's output.
     * The default matches the logistic sigmoid: {@code output * (1 - output)}.
     *
     * @param output A value previously returned by {@link #threshold}.
     * @return The derivative of the activation at that output.
     */
    public double thresholdDerivative(double output) {
      return output * (1 - output);
    }

    /**
     * Compute the output for a given input to the neural network.
     *
     * @param input The input provided to the neural network; only the first
     *     {@code inputCount} values are used.
     * @return A new array holding the values of the output neurons.
     * @throws IllegalArgumentException if {@code input} has fewer than {@code inputCount} values.
     */
    public double[] computeOutputs(double[] input) {
      if (input.length < inputCount) {
        throw new IllegalArgumentException(
            "Expected at least " + inputCount + " inputs, got " + input.length);
      }
      final int hiddenIndex = inputCount;
      final int outIndex = inputCount + hiddenCount;

      System.arraycopy(input, 0, neuronOutputs, 0, inputCount);

      // Hidden layer: weighted sum of the inputs plus threshold.
      int inx = 0;
      for (int i = hiddenIndex; i < outIndex; i++) {
        double sum = thresholds[i];
        for (int j = 0; j < inputCount; j++) {
          sum += neuronOutputs[j] * weightArray[inx++];
        }
        neuronOutputs[i] = threshold(sum);
      }

      // Output layer: weighted sum of the hidden outputs plus threshold.
      double[] result = new double[outputCount];
      for (int i = outIndex; i < neuronCount; i++) {
        double sum = thresholds[i];
        for (int j = hiddenIndex; j < outIndex; j++) {
          sum += neuronOutputs[j] * weightArray[inx++];
        }
        neuronOutputs[i] = threshold(sum);
        result[i - outIndex] = neuronOutputs[i];
      }

      return result;
    }

    /**
     * Calculate the error for the recognition just done by {@link #computeOutputs},
     * and accumulate the resulting weight and threshold deltas for {@link #learn}.
     *
     * @param ideal What the output neurons should have yielded.
     * @throws IllegalArgumentException if {@code ideal} has fewer than {@code outputCount} values.
     */
    public void calcError(double[] ideal) {
      if (ideal.length < outputCount) {
        throw new IllegalArgumentException(
            "Expected at least " + outputCount + " ideal values, got " + ideal.length);
      }
      final int hiddenIndex = inputCount;
      final int outputIndex = inputCount + hiddenCount;

      // Clear the errors of every neuron, including the inputs: input errors are
      // accumulated below and would otherwise keep growing across calls.
      for (int i = 0; i < neuronCount; i++) {
        error[i] = 0;
      }

      // Output layer errors and deltas.
      for (int i = outputIndex; i < neuronCount; i++) {
        error[i] = ideal[i - outputIndex] - neuronOutputs[i];
        globalError += error[i] * error[i];
        errorDelta[i] = error[i] * thresholdDerivative(neuronOutputs[i]);
      }

      // Hidden-to-output weight deltas; backpropagate error into the hidden layer.
      int winx = inputCount * hiddenCount; // start of the hidden-to-output weights
      for (int i = outputIndex; i < neuronCount; i++) {
        for (int j = hiddenIndex; j < outputIndex; j++) {
          accMatrixDelta[winx] += errorDelta[i] * neuronOutputs[j];
          error[j] += weightArray[winx] * errorDelta[i];
          winx++;
        }
        accThresholdDelta[i] += errorDelta[i];
      }

      // Hidden layer deltas.
      for (int i = hiddenIndex; i < outputIndex; i++) {
        errorDelta[i] = error[i] * thresholdDerivative(neuronOutputs[i]);
      }

      // Input-to-hidden weight deltas; backpropagate error into the input layer.
      winx = 0; // start of the input-to-hidden weights
      for (int i = hiddenIndex; i < outputIndex; i++) {
        for (int j = 0; j < hiddenIndex; j++) {
          accMatrixDelta[winx] += errorDelta[i] * neuronOutputs[j];
          error[j] += weightArray[winx] * errorDelta[i];
          winx++;
        }
        accThresholdDelta[i] += errorDelta[i];
      }
    }

    /**
     * Modify the weight matrix and thresholds using the deltas accumulated by
     * {@link #calcError} since the last call, then clear the accumulators. Each
     * change is {@code learnRate * accumulated + momentum * previousChange}.
     */
    public void learn() {
      for (int i = 0; i < weightArray.length; i++) {
        matrixDelta[i] = (learnRate * accMatrixDelta[i]) + (momentum * matrixDelta[i]);
        weightArray[i] += matrixDelta[i];
        accMatrixDelta[i] = 0;
      }

      // Input neurons have no threshold in the forward pass, so start at the hidden layer.
      for (int i = inputCount; i < neuronCount; i++) {
        thresholdDelta[i] = learnRate * accThresholdDelta[i] + (momentum * thresholdDelta[i]);
        thresholds[i] += thresholdDelta[i];
        accThresholdDelta[i] = 0;
      }
    }

    /**
     * Reset the weight matrix and the thresholds to random values in
     * {@code (-0.5, 0.5]} and clear all pending and momentum deltas.
     */
    public void reset() {
      for (int i = 0; i < neuronCount; i++) {
        thresholds[i] = 0.5 - (Math.random());
        thresholdDelta[i] = 0;
        accThresholdDelta[i] = 0;
      }

      for (int i = 0; i < weightArray.length; i++) {
        weightArray[i] = 0.5 - (Math.random());
        matrixDelta[i] = 0;
        accMatrixDelta[i] = 0;
      }
    }
  }
RAW Paste Data
We use cookies for various purposes including analytics. By continuing to use Pastebin, you agree to our use of cookies as described in the Cookies Policy. OK, I Understand
Not a member of Pastebin yet?
Sign Up, it unlocks many cool features!
 
Top