tzmtn

Java BackPropLearning issue

Apr 18th, 2012
47
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Java 7.65 KB | None | 0 0
  1. public class Tester
  2. {
  3.     public static void main(String[] args)
  4.     {
  5.         double[][] inputs = {
  6.             {0.547361, -2.04845, -2.71647},
  7.             {1.9708, -1.16141, -0.0485735},
  8.             {-3.18799, 2.97068, 0.26499},
  9.             {-0.498425, -1.04703, 1.61744}
  10.         };
  11.  
  12.         double[][] outputs = {
  13.             {1.0, 0.0, 0.0, 0.0},
  14.             {0.0, 1.0, 0.0, 0.0},
  15.             {0.0, 0.0, 1.0, 0.0},
  16.             {0.0, 0.0, 0.0, 1.0}
  17.         };
  18.        
  19.         Example[] examples = new Example[inputs.length];
  20.         for (int i = 0; i < examples.length; i++)
  21.             examples[i] = new Example(inputs[i], outputs[i]);
  22.        
  23.         int[] hiddenLayers = {4,4};  // sizes of hidden layers
  24.         Network net = new Network();
  25.        
  26.         net.backPropLearning(examples, hiddenLayers, 10000, 1.0);
  27.     }
  28. }
  29.  
  30.  
  31. public class Network
  32. {
  33.     private Layer[] layers;
  34.     static double reportMultiplier = -1.0;
  35.    
  36.     private void setLayer (Layer layer, int i)
  37.     {
  38.         layers[i] = layer;
  39.     }
  40.    
  41.     private void createLayers (int numInputs, int numOutputs, int[] hiddenLayers)
  42.     {
  43.         layers = new Layer[hiddenLayers.length + 2];
  44.        
  45.         Layer inputLayer = new Layer(numInputs);
  46.         setLayer(inputLayer, 0);
  47.        
  48.         for (int i = 0; i < hiddenLayers.length; i++)
  49.             setLayer(new Layer(layers[i], hiddenLayers[i]), i + 1);
  50.        
  51.         Layer outputLayer = new Layer(layers[hiddenLayers.length], numOutputs);
  52.         setLayer(outputLayer, layers.length - 1);
  53.        
  54.         for (int i = 0; i < layers.length; i++)
  55.             layers[i].createNeurons();
  56.     }
  57.    
  58.     private double activation (double x)
  59.     {
  60.         return 1.0 / (1.0 + Math.exp(-x)); // sigmoid
  61.     }
  62.    
  63.     private double activationDerivative (double x)
  64.     {
  65.         double exp = Math.exp(x);
  66.         return exp / Math.pow((exp + 1.0), 2); // d/dx sigmoid
  67.     }
  68.    
  69.     private void printReport (int n, Example[] examples)
  70.     {
  71.         String s = "";
  72.         double errorSum = 0.0;
  73.         int m = layers.length - 1;
  74.         for (int i = 0; i < layers[m].getSize(); i++)
  75.             for (int j = 0; j < examples.length; j++)
  76.                 errorSum += Math.abs(examples[j].getOutput(i) - (layers[m].getNeuron(i)).getOutput());
  77.        
  78.         s = "n:" + n;
  79.         while (s.length() < 10)
  80.             s += " ";
  81.        
  82.         s += "error:" + (float) errorSum;
  83.         while (s.length() < 30)
  84.             s += " ";
  85.        
  86.         System.out.print(s);
  87.        
  88.         if (reportMultiplier < 0)
  89.             reportMultiplier = 50 / errorSum;
  90.        
  91.         for (int i = 0; i < errorSum * reportMultiplier; i++)
  92.             System.out.print("|");
  93.        
  94.         System.out.println();
  95.     }
  96.    
  97.     public void backPropLearning (Example[] examples, int[] hiddenLayers, int steps, double learningFactor)
  98.     {
  99.         createLayers(examples[0].numInputs(), examples[0].numOutputs(), hiddenLayers);
  100.         int n = 0;
  101.         int numPlots = 50;
  102.         int plotStep = Math.max(1, steps / numPlots);
  103.        
  104.         while (n < steps) {
  105.             for (int e = 0; e < examples.length; e++) {
  106.                 Example example = examples[e];
  107.                
  108.                 // Propagate the inputs forward to compute the output
  109.                 for (int i = 0; i < layers[0].getSize(); i++)
  110.                     layers[0].getNeuron(i).setOutput(example.getInput(i));
  111.                
  112.                 for (int i = 1; i < layers.length; i++) {
  113.                     for (int j = 0; j < layers[i].getSize(); j++) {
  114.                         Neuron nj = layers[i].getNeuron(j);
  115.                         double in = 0.0;
  116.                         for (int k = 0; k < nj.numInputs(); k++) {
  117.                             Axon akj = nj.getInput(k);
  118.                             Neuron nk = akj.getStart();
  119.                             in += akj.getWeight() * nk.getOutput();
  120.                         }
  121.                         nj.setInput(in);
  122.                         nj.setOutput(activation(in));
  123.                     }
  124.                 }
  125.                
  126.                 // Propagate deltas backward from output to input layer
  127.                 int m = layers.length - 1;
  128.                 for (int i = 0; i < layers[m].getSize(); i++) {
  129.                     Neuron ni = layers[m].getNeuron(i);
  130.                     double errori = example.getOutput(i) - ni.getOutput();
  131.                     ni.setDelta(activationDerivative(ni.getInput()) * errori);
  132.                 }
  133.                
  134.                 for (int i = m - 1; i >= 0; i--) {
  135.                     for (int j = 0; j < layers[i].getSize(); j++) {
  136.                         Neuron nj = layers[i].getNeuron(j);
  137.                         double delta = 0.0;
  138.                         for (int k = 0; k < nj.numOutputs(); k++) {
  139.                             Axon ajk = nj.getOutput(k);
  140.                             Neuron nk = ajk.getEnd();
  141.                             delta += ajk.getWeight() * nk.getDelta();
  142.                         }
  143.                         delta *= activationDerivative(nj.getInput());
  144.                         nj.setDelta(delta);
  145.                     }
  146.                 }
  147.                
  148.                 // Update every weight in network using deltas
  149.                 for (int i = 0; i < layers.length - 1; i++) {
  150.                     for (int j = 0; j < layers[i].getSize(); j++) {
  151.                         Neuron nj = layers[i].getNeuron(j);
  152.                        
  153.                         for (int k = 0; k < nj.numOutputs(); k++) {
  154.                             Axon ajk = nj.getOutput(k);
  155.                             Neuron nk = ajk.getEnd();
  156.                             ajk.modifyWeight(learningFactor * nk.getDelta() * nj.getOutput());
  157.                             ajk.modifyWeight(learningFactor * nj.getDelta() * nk.getOutput());
  158.                         }
  159.                     }
  160.                 }
  161.             }
  162.            
  163.             if (n % plotStep == 0)
  164.                 printReport(n, examples);
  165.            
  166.             n++;
  167.         }
  168.        
  169.         printReport(n, examples);
  170.     }
  171. }
  172.  
  173.  
  174. public class Layer
  175. {
  176.     private Neuron[] neurons;
  177.     private Layer previous;
  178.     private Layer next;
  179.    
  180.     public Layer (int n)
  181.     {
  182.         neurons = new Neuron[n];
  183.     }
  184.    
  185.     public Layer (Layer l, int n)
  186.     {
  187.         this(n);
  188.        
  189.         previous = l;
  190.         previous.next = this;
  191.     }
  192.    
  193.     public void createNeurons ()
  194.     {
  195.         int n = previous == null ? 0 : previous.neurons.length;
  196.         int m = next == null ? 0 : next.neurons.length;
  197.        
  198.         for (int i = 0; i < neurons.length; i++) {
  199.             neurons[i] = new Neuron(n, m);
  200.             for (int j = 0; j < n; j++)
  201.                 neurons[i].addInput(previous.neurons[j]);
  202.         }
  203.     }
  204.    
  205.     public int getSize ()
  206.     {
  207.         return neurons.length;
  208.     }
  209.    
  210.     public Neuron getNeuron (int i)
  211.     {
  212.         return neurons[i];
  213.     }
  214. }
  215.  
  216.  
  217. public class Neuron
  218. {
  219.     private Axon[] inputs;
  220.     private Axon[] outputs;
  221.     private double input;
  222.     private double output;
  223.     private double delta;
  224.    
  225.     public Neuron () // dummy constructor
  226.     {
  227.         inputs = new Axon[0];
  228.         outputs = new Axon[1];
  229.         input = 1.0;
  230.         output = 1.0;
  231.         delta = 0.0;
  232.     }
  233.    
  234.     public Neuron (int n, int m)
  235.     {
  236.         this();
  237.         n += 1; // add the dummy input
  238.         inputs = new Axon[n];
  239.         outputs = new Axon[m];
  240.        
  241.         Neuron dummy = new Neuron();
  242.         addInput(dummy);
  243.     }
  244.    
  245.     private void append (Axon[] arr, Axon a)
  246.     {
  247.         int i = 0;
  248.         while (i < arr.length && arr[i] != null)
  249.             i++;
  250.         if (i < arr.length)
  251.             arr[i] = a;
  252.     }
  253.    
  254.     private void addOutput (Axon a)
  255.     {
  256.         append(outputs, a);
  257.     }
  258.    
  259.     public void addInput (Neuron n)
  260.     {
  261.         Axon a = new Axon(n, this);
  262.         append(inputs, a);
  263.         n.addOutput(a);
  264.     }
  265.    
  266.     public double getInput ()
  267.     {
  268.         return input;
  269.     }
  270.    
  271.     public void setInput (double in)
  272.     {
  273.         input = in;
  274.     }
  275.    
  276.     public double getOutput ()
  277.     {
  278.         return output;
  279.     }
  280.    
  281.     public void setOutput (double out)
  282.     {
  283.         output = out;
  284.     }
  285.    
  286.     public double getDelta ()
  287.     {
  288.         return delta;
  289.     }
  290.    
  291.     public void setDelta (double newDelta)
  292.     {
  293.         delta = newDelta;
  294.     }
  295.    
  296.     public int numInputs ()
  297.     {
  298.         return inputs.length;
  299.     }
  300.    
  301.     public int numOutputs ()
  302.     {
  303.         return outputs.length;
  304.     }
  305.    
  306.     public Axon getInput (int i)
  307.     {
  308.         return inputs[i];
  309.     }
  310.    
  311.     public Axon getOutput (int i)
  312.     {
  313.         return outputs[i];
  314.     }
  315. }
  316.  
  317.  
  318. public class Axon
  319. {
  320.     private double weight;
  321.     private Neuron start;
  322.     private Neuron end;
  323.    
  324.     public Axon (Neuron s, Neuron e)
  325.     {
  326.         start = s;
  327.         end = e;
  328.         weight = 2.0 * Math.random() - 1.0; // random initial weight between -1 and 1
  329.     }
  330.    
  331.     public Neuron getStart ()
  332.     {
  333.         return start;
  334.     }
  335.    
  336.     public Neuron getEnd ()
  337.     {
  338.         return end;
  339.     }
  340.    
  341.     public double getWeight ()
  342.     {
  343.         return weight;
  344.     }
  345.    
  346.     public void modifyWeight (double x)
  347.     {
  348.         weight += x;
  349.     }
  350. }
  351.  
  352. public class Example
  353. {
  354.     double[] inputs;
  355.     double[] outputs;
  356.    
  357.     public Example (double[] in, double[] out)
  358.     {
  359.         inputs = in;
  360.         outputs = out;
  361.     }
  362.    
  363.     public double getInput (int i)
  364.     {
  365.         return inputs[i];
  366.     }
  367.    
  368.     public double getOutput (int i)
  369.     {
  370.         return outputs[i];
  371.     }
  372.    
  373.     public int numInputs ()
  374.     {
  375.         return inputs.length;
  376.     }
  377.    
  378.     public int numOutputs ()
  379.     {
  380.         return outputs.length;
  381.     }
  382. }
Add Comment
Please, Sign In to add comment