Pastebin launched a little side project called VERYVIRAL.com, check it out ;-) Want more features on Pastebin? Sign Up, it's FREE!
Guest

Java BackPropLearning issue

By: tzmtn on Apr 18th, 2012  |  syntax: Java  |  size: 7.65 KB  |  views: 35  |  expires: Never
download  |  raw  |  embed  |  report abuse  |  print
Text below is selected. Please press Ctrl+C to copy to your clipboard. (⌘+C on Mac)
  1. public class Tester
  2. {
  3.         public static void main(String[] args)
  4.         {
  5.                 double[][] inputs = {
  6.                         {0.547361, -2.04845, -2.71647},
  7.                         {1.9708, -1.16141, -0.0485735},
  8.                         {-3.18799, 2.97068, 0.26499},
  9.                         {-0.498425, -1.04703, 1.61744}
  10.                 };
  11.  
  12.                 double[][] outputs = {
  13.                         {1.0, 0.0, 0.0, 0.0},
  14.                         {0.0, 1.0, 0.0, 0.0},
  15.                         {0.0, 0.0, 1.0, 0.0},
  16.                         {0.0, 0.0, 0.0, 1.0}
  17.                 };
  18.                
  19.                 Example[] examples = new Example[inputs.length];
  20.                 for (int i = 0; i < examples.length; i++)
  21.                         examples[i] = new Example(inputs[i], outputs[i]);
  22.                
  23.                 int[] hiddenLayers = {4,4};  // sizes of hidden layers
  24.                 Network net = new Network();
  25.                
  26.                 net.backPropLearning(examples, hiddenLayers, 10000, 1.0);
  27.         }
  28. }
  29.  
  30.  
  31. public class Network
  32. {
  33.         private Layer[] layers;
  34.         static double reportMultiplier = -1.0;
  35.        
  36.         private void setLayer (Layer layer, int i)
  37.         {
  38.                 layers[i] = layer;
  39.         }
  40.        
  41.         private void createLayers (int numInputs, int numOutputs, int[] hiddenLayers)
  42.         {
  43.                 layers = new Layer[hiddenLayers.length + 2];
  44.                
  45.                 Layer inputLayer = new Layer(numInputs);
  46.                 setLayer(inputLayer, 0);
  47.                
  48.                 for (int i = 0; i < hiddenLayers.length; i++)
  49.                         setLayer(new Layer(layers[i], hiddenLayers[i]), i + 1);
  50.                
  51.                 Layer outputLayer = new Layer(layers[hiddenLayers.length], numOutputs);
  52.                 setLayer(outputLayer, layers.length - 1);
  53.                
  54.                 for (int i = 0; i < layers.length; i++)
  55.                         layers[i].createNeurons();
  56.         }
  57.        
  58.         private double activation (double x)
  59.         {
  60.                 return 1.0 / (1.0 + Math.exp(-x)); // sigmoid
  61.         }
  62.        
  63.         private double activationDerivative (double x)
  64.         {
  65.                 double exp = Math.exp(x);
  66.                 return exp / Math.pow((exp + 1.0), 2); // d/dx sigmoid
  67.         }
  68.        
  69.         private void printReport (int n, Example[] examples)
  70.         {
  71.                 String s = "";
  72.                 double errorSum = 0.0;
  73.                 int m = layers.length - 1;
  74.                 for (int i = 0; i < layers[m].getSize(); i++)
  75.                         for (int j = 0; j < examples.length; j++)
  76.                                 errorSum += Math.abs(examples[j].getOutput(i) - (layers[m].getNeuron(i)).getOutput());
  77.                
  78.                 s = "n:" + n;
  79.                 while (s.length() < 10)
  80.                         s += " ";
  81.                
  82.                 s += "error:" + (float) errorSum;
  83.                 while (s.length() < 30)
  84.                         s += " ";
  85.                
  86.                 System.out.print(s);
  87.                
  88.                 if (reportMultiplier < 0)
  89.                         reportMultiplier = 50 / errorSum;
  90.                
  91.                 for (int i = 0; i < errorSum * reportMultiplier; i++)
  92.                         System.out.print("|");
  93.                
  94.                 System.out.println();
  95.         }
  96.        
  97.         public void backPropLearning (Example[] examples, int[] hiddenLayers, int steps, double learningFactor)
  98.         {
  99.                 createLayers(examples[0].numInputs(), examples[0].numOutputs(), hiddenLayers);
  100.                 int n = 0;
  101.                 int numPlots = 50;
  102.                 int plotStep = Math.max(1, steps / numPlots);
  103.                
  104.                 while (n < steps) {
  105.                         for (int e = 0; e < examples.length; e++) {
  106.                                 Example example = examples[e];
  107.                                
  108.                                 // Propagate the inputs forward to compute the output
  109.                                 for (int i = 0; i < layers[0].getSize(); i++)
  110.                                         layers[0].getNeuron(i).setOutput(example.getInput(i));
  111.                                
  112.                                 for (int i = 1; i < layers.length; i++) {
  113.                                         for (int j = 0; j < layers[i].getSize(); j++) {
  114.                                                 Neuron nj = layers[i].getNeuron(j);
  115.                                                 double in = 0.0;
  116.                                                 for (int k = 0; k < nj.numInputs(); k++) {
  117.                                                         Axon akj = nj.getInput(k);
  118.                                                         Neuron nk = akj.getStart();
  119.                                                         in += akj.getWeight() * nk.getOutput();
  120.                                                 }
  121.                                                 nj.setInput(in);
  122.                                                 nj.setOutput(activation(in));
  123.                                         }
  124.                                 }
  125.                                
  126.                                 // Propagate deltas backward from output to input layer
  127.                                 int m = layers.length - 1;
  128.                                 for (int i = 0; i < layers[m].getSize(); i++) {
  129.                                         Neuron ni = layers[m].getNeuron(i);
  130.                                         double errori = example.getOutput(i) - ni.getOutput();
  131.                                         ni.setDelta(activationDerivative(ni.getInput()) * errori);
  132.                                 }
  133.                                
  134.                                 for (int i = m - 1; i >= 0; i--) {
  135.                                         for (int j = 0; j < layers[i].getSize(); j++) {
  136.                                                 Neuron nj = layers[i].getNeuron(j);
  137.                                                 double delta = 0.0;
  138.                                                 for (int k = 0; k < nj.numOutputs(); k++) {
  139.                                                         Axon ajk = nj.getOutput(k);
  140.                                                         Neuron nk = ajk.getEnd();
  141.                                                         delta += ajk.getWeight() * nk.getDelta();
  142.                                                 }
  143.                                                 delta *= activationDerivative(nj.getInput());
  144.                                                 nj.setDelta(delta);
  145.                                         }
  146.                                 }
  147.                                
  148.                                 // Update every weight in network using deltas
  149.                                 for (int i = 0; i < layers.length - 1; i++) {
  150.                                         for (int j = 0; j < layers[i].getSize(); j++) {
  151.                                                 Neuron nj = layers[i].getNeuron(j);
  152.                                                
  153.                                                 for (int k = 0; k < nj.numOutputs(); k++) {
  154.                                                         Axon ajk = nj.getOutput(k);
  155.                                                         Neuron nk = ajk.getEnd();
  156.                                                         ajk.modifyWeight(learningFactor * nk.getDelta() * nj.getOutput());
  157.                                                         ajk.modifyWeight(learningFactor * nj.getDelta() * nk.getOutput());
  158.                                                 }
  159.                                         }
  160.                                 }
  161.                         }
  162.                        
  163.                         if (n % plotStep == 0)
  164.                                 printReport(n, examples);
  165.                        
  166.                         n++;
  167.                 }
  168.                
  169.                 printReport(n, examples);
  170.         }
  171. }
  172.  
  173.  
  174. public class Layer
  175. {
  176.         private Neuron[] neurons;
  177.         private Layer previous;
  178.         private Layer next;
  179.        
  180.         public Layer (int n)
  181.         {
  182.                 neurons = new Neuron[n];
  183.         }
  184.        
  185.         public Layer (Layer l, int n)
  186.         {
  187.                 this(n);
  188.                
  189.                 previous = l;
  190.                 previous.next = this;
  191.         }
  192.        
  193.         public void createNeurons ()
  194.         {
  195.                 int n = previous == null ? 0 : previous.neurons.length;
  196.                 int m = next == null ? 0 : next.neurons.length;
  197.                
  198.                 for (int i = 0; i < neurons.length; i++) {
  199.                         neurons[i] = new Neuron(n, m);
  200.                         for (int j = 0; j < n; j++)
  201.                                 neurons[i].addInput(previous.neurons[j]);
  202.                 }
  203.         }
  204.        
  205.         public int getSize ()
  206.         {
  207.                 return neurons.length;
  208.         }
  209.        
  210.         public Neuron getNeuron (int i)
  211.         {
  212.                 return neurons[i];
  213.         }
  214. }
  215.  
  216.  
  217. public class Neuron
  218. {
  219.         private Axon[] inputs;
  220.         private Axon[] outputs;
  221.         private double input;
  222.         private double output;
  223.         private double delta;
  224.        
  225.         public Neuron () // dummy constructor
  226.         {
  227.                 inputs = new Axon[0];
  228.                 outputs = new Axon[1];
  229.                 input = 1.0;
  230.                 output = 1.0;
  231.                 delta = 0.0;
  232.         }
  233.        
  234.         public Neuron (int n, int m)
  235.         {
  236.                 this();
  237.                 n += 1; // add the dummy input
  238.                 inputs = new Axon[n];
  239.                 outputs = new Axon[m];
  240.                
  241.                 Neuron dummy = new Neuron();
  242.                 addInput(dummy);
  243.         }
  244.        
  245.         private void append (Axon[] arr, Axon a)
  246.         {
  247.                 int i = 0;
  248.                 while (i < arr.length && arr[i] != null)
  249.                         i++;
  250.                 if (i < arr.length)
  251.                         arr[i] = a;
  252.         }
  253.        
  254.         private void addOutput (Axon a)
  255.         {
  256.                 append(outputs, a);
  257.         }
  258.        
  259.         public void addInput (Neuron n)
  260.         {
  261.                 Axon a = new Axon(n, this);
  262.                 append(inputs, a);
  263.                 n.addOutput(a);
  264.         }
  265.        
  266.         public double getInput ()
  267.         {
  268.                 return input;
  269.         }
  270.        
  271.         public void setInput (double in)
  272.         {
  273.                 input = in;
  274.         }
  275.        
  276.         public double getOutput ()
  277.         {
  278.                 return output;
  279.         }
  280.        
  281.         public void setOutput (double out)
  282.         {
  283.                 output = out;
  284.         }
  285.        
  286.         public double getDelta ()
  287.         {
  288.                 return delta;
  289.         }
  290.        
  291.         public void setDelta (double newDelta)
  292.         {
  293.                 delta = newDelta;
  294.         }
  295.        
  296.         public int numInputs ()
  297.         {
  298.                 return inputs.length;
  299.         }
  300.        
  301.         public int numOutputs ()
  302.         {
  303.                 return outputs.length;
  304.         }
  305.        
  306.         public Axon getInput (int i)
  307.         {
  308.                 return inputs[i];
  309.         }
  310.        
  311.         public Axon getOutput (int i)
  312.         {
  313.                 return outputs[i];
  314.         }
  315. }
  316.  
  317.  
  318. public class Axon
  319. {
  320.         private double weight;
  321.         private Neuron start;
  322.         private Neuron end;
  323.        
  324.         public Axon (Neuron s, Neuron e)
  325.         {
  326.                 start = s;
  327.                 end = e;
  328.                 weight = 2.0 * Math.random() - 1.0; // random initial weight between -1 and 1
  329.         }
  330.        
  331.         public Neuron getStart ()
  332.         {
  333.                 return start;
  334.         }
  335.        
  336.         public Neuron getEnd ()
  337.         {
  338.                 return end;
  339.         }
  340.        
  341.         public double getWeight ()
  342.         {
  343.                 return weight;
  344.         }
  345.        
  346.         public void modifyWeight (double x)
  347.         {
  348.                 weight += x;
  349.         }
  350. }
  351.  
  352. public class Example
  353. {
  354.         double[] inputs;
  355.         double[] outputs;
  356.        
  357.         public Example (double[] in, double[] out)
  358.         {
  359.                 inputs = in;
  360.                 outputs = out;
  361.         }
  362.        
  363.         public double getInput (int i)
  364.         {
  365.                 return inputs[i];
  366.         }
  367.        
  368.         public double getOutput (int i)
  369.         {
  370.                 return outputs[i];
  371.         }
  372.        
  373.         public int numInputs ()
  374.         {
  375.                 return inputs.length;
  376.         }
  377.        
  378.         public int numOutputs ()
  379.         {
  380.                 return outputs.length;
  381.         }
  382. }