Advertisement
hasib_mo

Neural Network - working code for MNIST dataset

Aug 28th, 2014
303
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Java 11.43 KB | None | 0 0
  1. /*
  2. import java.io.DataInputStream;
  3. import java.io.FileInputStream;
  4. import java.io.IOException;
  5. */
  6. import java.io.*;
  7. import java.lang.Math;
  8.  
  9.  
  10. class Neuron implements Serializable{
  11.     double weightVector[];
  12.     int numInEdges;
  13.     double y, z;
  14.     double delta;
  15.     double bigDelta[];
  16.  
  17.     public Neuron(int inputEdges)
  18.     {
  19.         numInEdges = inputEdges;
  20.         weightVector = new double[numInEdges];
  21.         bigDelta = new double[numInEdges];
  22.  
  23.         for(int i=0; i<weightVector.length; i++)
  24.         {
  25.             weightVector[i] = (Math.random() - 0.5);
  26.         }
  27.     }
  28. }
  29.  
  30.  
  31. class NeuronLayer implements Serializable{
  32.     int numNeurons;
  33.     Neuron neurons[];
  34.  
  35.     public NeuronLayer(int neuronInThisLayer)
  36.     {
  37.         numNeurons = neuronInThisLayer + 1;
  38.         neurons = new Neuron[numNeurons];
  39.     }
  40. }
  41.  
  42.  
  43. class NeuralNetwork implements Serializable{
  44.     int numInputUnit;
  45.     int numOutputUnit;
  46.     int numHiddenLayer;
  47.     int numNeuronPerHiddenLayer;
  48.     int numTotalLayer;
  49.  
  50.     NeuronLayer neuronlayers[];
  51.  
  52.     int L; //id of last layer
  53.  
  54.     double learningRate, prevError, currentError;
  55.  
  56.  
  57.     int numTrained;
  58.  
  59.  
  60.  
  61.  
  62.     public NeuralNetwork(int _inputUnit, int _outputUnit, int _hiddenLayer, int _neuronPerHiddenLayer)
  63.     {
  64.         learningRate = 1.0;
  65.         prevError = Double.POSITIVE_INFINITY;
  66.         numTrained = 0;
  67.  
  68.  
  69.         numInputUnit = _inputUnit;
  70.         numOutputUnit = _outputUnit;
  71.         numHiddenLayer = _hiddenLayer;
  72.         numNeuronPerHiddenLayer = _neuronPerHiddenLayer;
  73.  
  74.         numTotalLayer = 2 + numHiddenLayer;
  75.  
  76.         L = numTotalLayer - 1;
  77.  
  78.         neuronlayers = new NeuronLayer[numTotalLayer];
  79.  
  80.         neuronlayers[0] = new NeuronLayer(numInputUnit);
  81.         neuronlayers[L] = new NeuronLayer(numOutputUnit);
  82.  
  83.         for(int i=1; i < L; i++)
  84.         {
  85.             neuronlayers[i] = new NeuronLayer(numNeuronPerHiddenLayer);
  86.         }
  87.  
  88.         for(int layer=0; layer<=L; layer++)
  89.         {
  90.             for(int i=0; i<neuronlayers[layer].numNeurons; i++)
  91.             {
  92.                 if(layer == 0)
  93.                 {
  94.                     neuronlayers[layer].neurons[i] = new Neuron(0);
  95.                 }
  96.                 else
  97.                 {
  98.                     neuronlayers[layer].neurons[i] = new Neuron( neuronlayers[layer-1].numNeurons );
  99.                 }
  100.             }
  101.             neuronlayers[layer].neurons[ neuronlayers[layer].numNeurons - 1].y = 1; //bias node
  102.         }
  103.     }
  104.  
  105.     public int feedForward(double x[])
  106.     {
  107.         for(int i=0; i<x.length; i++)
  108.         {
  109.             neuronlayers[0].neurons[i].y = (x[i] / 127.5) - 1.0;
  110.         }
  111.  
  112.         for(int i=1; i<=L; i++)
  113.         {
  114.             feedALevel(neuronlayers[i-1], neuronlayers[i]);
  115.         }
  116.        
  117.  
  118.         double maxVal = -1;
  119.         int idOfMaxVal = -1;
  120.        
  121.         for(int i=0; i<numOutputUnit; i++)
  122.         {
  123.             if(neuronlayers[L].neurons[i].y > maxVal)
  124.             {
  125.                 maxVal = neuronlayers[L].neurons[i].y;
  126.                 idOfMaxVal = i;
  127.             }
  128.         }
  129.         return idOfMaxVal;
  130.  
  131.     }
  132.  
  133.     public void feedALevel(NeuronLayer prevLayer, NeuronLayer curLayer)
  134.     {
  135.         for(int n=0; n < curLayer.numNeurons - 1; n++)
  136.         {
  137.             curLayer.neurons[n].z = 0;
  138.  
  139.             for(int i=0; i<prevLayer.numNeurons; i++)
  140.             {
  141.                 curLayer.neurons[n].z += curLayer.neurons[n].weightVector[i] * prevLayer.neurons[i].y;
  142.             }
  143.  
  144.             curLayer.neurons[n].y = sigmoid( curLayer.neurons[n].z );
  145.  
  146.         }
  147.     }
  148.  
  149.  
  150.     public double calcError(double x[][], double t[][], int m)
  151.     {
  152.         double error = 0;
  153.         double hx;
  154.  
  155.         for(int i=0; i<m; i++)
  156.         {
  157.             feedForward(x[i]);
  158.  
  159.             for(int j = 0; j<10; j++)
  160.             {
  161.                 hx = neuronlayers[L].neurons[j].y;
  162.                 error += ( -t[i][j]*Math.log(hx) - (1 - t[i][j])*Math.log(1 - hx) );
  163.             }
  164.         }
  165.         error = error / m;
  166.  
  167.         System.out.println("cost J = "+ error + " at learningRate = " + learningRate);
  168.  
  169.         return error;
  170.     }
  171.  
  172.  
  173.  
  174.     public void backPropagate(double x[][], double t[][], int m)
  175.     {
  176.         for(int kase = 0; kase<m; kase++)
  177.         {
  178.             feedForward(x[kase]);
  179.             for(int i=0; i<numOutputUnit; i++)
  180.             {
  181.                 neuronlayers[L].neurons[i].delta = neuronlayers[L].neurons[i].y - t[kase][i];
  182.             }
  183.  
  184.             for(int i=L-1; i>0; i--)
  185.             {
  186.                 calcDelta(neuronlayers[i], neuronlayers[i+1]);
  187.             }
  188.  
  189.             for(int i=L; i>0; i--)
  190.             {
  191.                 for(int j=0; j<neuronlayers[i].numNeurons-1; j++)
  192.                 {
  193.                     for(int k=0; k<neuronlayers[i-1].numNeurons; k++)
  194.                     {
  195.                         neuronlayers[i].neurons[j].bigDelta[k] += neuronlayers[i-1].neurons[k].y * neuronlayers[i].neurons[j].delta;
  196.                     }
  197.                 }
  198.             }
  199.         }
  200.  
  201.         for(int i=1; i<=L; i++)
  202.         {
  203.             for(int j=0; j<neuronlayers[i].numNeurons-1; j++)
  204.             {
  205.                 for(int k=0; k<neuronlayers[i].neurons[j].numInEdges; k++)
  206.                 {
  207.                     neuronlayers[i].neurons[j].weightVector[k] -= (learningRate * neuronlayers[i].neurons[j].bigDelta[k] /*+ 0.01*neuronlayers[i].neurons[j].weightVector[k] */ )/(double)m;
  208.                     neuronlayers[i].neurons[j].bigDelta[k] = 0;
  209.                 }
  210.             }
  211.         }
  212.  
  213.  
  214.  
  215.         /************************/
  216.         /*
  217.          * error calculation and fixing learning rate depending on previous error
  218.          * and current error.
  219.         */
  220.  
  221.         currentError = calcError(x, t, m);
  222.         if(currentError < prevError )
  223.         {
  224.             learningRate = learningRate * 1.04;
  225.         }
  226.         else
  227.         {
  228.             learningRate = learningRate * 0.7;
  229.         }
  230.  
  231.         prevError = currentError;
  232.  
  233.         numTrained++;
  234.  
  235.     }
  236.  
  237.  
  238.     public void calcDelta(NeuronLayer curLayer, NeuronLayer forwardLayer)
  239.     {
  240.         for(int i=0; i < curLayer.numNeurons - 1; i++)
  241.         {
  242.             double delta = 0;
  243.             for(int j=0; j<forwardLayer.numNeurons-1; j++)
  244.             {
  245.                 delta += forwardLayer.neurons[j].weightVector[i] * forwardLayer.neurons[j].delta;
  246.             }
  247.             delta = delta * curLayer.neurons[i].y * (1.0 - curLayer.neurons[i].y);
  248.             curLayer.neurons[i].delta = delta;
  249.         }
  250.     }
  251.  
  252.  
  253.     public double sigmoid(double z)
  254.     {
  255.         return 1.0/(1.0 + Math.exp(-z));
  256.     }
  257.  
  258.  
  259.  
  260. }
  261.  
  262.  
  263. class SadaKhata{
  264.     NeuralNetwork neuralnetwork;
  265.     public void saveData()
  266.     {
  267.         try{
  268.             ObjectOutputStream objectoutputstream = new ObjectOutputStream(new FileOutputStream("weights-3.dat"));
  269.             objectoutputstream.writeObject(neuralnetwork);
  270.         }catch(Exception ex)
  271.         {
  272.             System.out.println("<<<<<<<<< COULDN'T WRITE OBJECT >>>>>>>>>>>");
  273.         }
  274.     }
  275.  
  276.     public void loadData(int _inputUnit, int _outputUnit , int _hiddenLayer , int _neuronPerHiddenLayer )
  277.     {
  278.         try{
  279.             ObjectInputStream objectinputstream = new ObjectInputStream(new FileInputStream("weights-3.dat"));
  280.             try{
  281.                 neuralnetwork = (NeuralNetwork) objectinputstream.readObject();
  282.             }catch(Exception ex)
  283.             {
  284.                 System.out.println("<<< COULDN'T READ OBJECT >>>");
  285.                 neuralnetwork = new NeuralNetwork(_inputUnit, _outputUnit, _hiddenLayer, _neuronPerHiddenLayer);
  286.             }
  287.         }catch(Exception ex)
  288.         {
  289.                 System.out.println("<<< COULDN'T READ OBJECT >>>");
  290.                 neuralnetwork = new NeuralNetwork(_inputUnit, _outputUnit, _hiddenLayer, _neuronPerHiddenLayer);
  291.         }
  292.     }
  293. }
  294.  
  295.  
  296.  
  297.  
  298.  
  299.  
  300.  
  301.  
  302.  
  303.  
  304. public class MNISTReader {
  305.  
  306.   /**
  307.    * @param args
  308.    *          args[0]: label file; args[1]: data file.
  309.    * @throws IOException
  310.    */
  311.   public static void main(String[] args) throws IOException {
  312.     DataInputStream labels = new DataInputStream(new FileInputStream("train-labels"));
  313.     DataInputStream images = new DataInputStream(new FileInputStream("train-images"));
  314.     int magicNumber = labels.readInt();
  315.     if (magicNumber != 2049) {
  316.       System.err.println("Label file has wrong magic number: " + magicNumber + " (should be 2049)");
  317.       System.exit(0);
  318.     }
  319.     magicNumber = images.readInt();
  320.     if (magicNumber != 2051) {
  321.       System.err.println("Image file has wrong magic number: " + magicNumber + " (should be 2051)");
  322.       System.exit(0);
  323.     }
  324.     int numLabels = labels.readInt();
  325.     int numImages = images.readInt();
  326.     int numRows = images.readInt();
  327.     int numCols = images.readInt();
  328.     if (numLabels != numImages) {
  329.       System.err.println("Image file and label file do not contain the same number of entries.");
  330.       System.err.println("  Label file contains: " + numLabels);
  331.       System.err.println("  Image file contains: " + numImages);
  332.       System.exit(0);
  333.     }
  334.  
  335.     long start = System.currentTimeMillis();
  336.     int numLabelsRead = 0;
  337.     int numImagesRead = 0;
  338.  
  339.     /******************************/
  340.  
  341.     double[][] x = new double[60000][784];
  342.     double[][] t = new double[60000][10];
  343.  
  344.     int kaseno = 0;
  345.  
  346.  
  347.  
  348.     while (labels.available() > 0 && numLabelsRead < numLabels) {
  349.       byte label = labels.readByte();
  350.       numLabelsRead++;
  351.       int[][] image = new int[numCols][numRows];
  352.  
  353.  
  354.       for (int colIdx = 0; colIdx < numCols; colIdx++) {
  355.         for (int rowIdx = 0; rowIdx < numRows; rowIdx++) {
  356.           image[colIdx][rowIdx] = images.readUnsignedByte();
  357.         }
  358.       }
  359.       numImagesRead++;
  360.  
  361.       int m = 0;
  362.       for(int i=0; i<28; i++)
  363.       {
  364.         for(int j=0; j<28; j++)
  365.         {
  366.             x[kaseno][m] = image[i][j];
  367.             m++;
  368.         }
  369.       }
  370.  
  371.       for(int i=0; i<10; i++) t[kaseno][i] = 0;
  372.       t[kaseno][ label ] = 1;
  373.  
  374.       kaseno++;
  375.  
  376.     }
  377.  
  378.  
  379.  
  380.  
  381.     SadaKhata sk = new SadaKhata();
  382.     sk.loadData(784, 10, 2, 50);
  383.  
  384.    
  385.  
  386.     for(int loop = 0; loop<10; loop++)
  387.     {
  388.         System.out.print(loop+1 + ". ");
  389.         sk.neuralnetwork.backPropagate(x, t, 60000);
  390.     }
  391.  
  392.     System.out.println("Total number of backPropagation = " + sk.neuralnetwork.numTrained);
  393.  
  394.     sk.saveData();
  395.    
  396.  
  397.  
  398.  
  399.  
  400.    
  401.  
  402.     labels = new DataInputStream(new FileInputStream("test-labels"));
  403.     images = new DataInputStream(new FileInputStream("test-images"));
  404.  
  405.  
  406.     magicNumber = labels.readInt();
  407.     if (magicNumber != 2049) {
  408.       System.err.println("Label file has wrong magic number: " + magicNumber + " (should be 2049)");
  409.       System.exit(0);
  410.     }
  411.     magicNumber = images.readInt();
  412.     if (magicNumber != 2051) {
  413.       System.err.println("Image file has wrong magic number: " + magicNumber + " (should be 2051)");
  414.       System.exit(0);
  415.     }
  416.     numLabels = labels.readInt();
  417.     numImages = images.readInt();
  418.     numRows = images.readInt();
  419.     numCols = images.readInt();
  420.     if (numLabels != numImages) {
  421.       System.err.println("Image file and label file do not contain the same number of entries.");
  422.       System.err.println("  Label file contains: " + numLabels);
  423.       System.err.println("  Image file contains: " + numImages);
  424.       System.exit(0);
  425.     }
  426.  
  427.     //long start = System.currentTimeMillis();
  428.     numLabelsRead = 0;
  429.     numImagesRead = 0;
  430.  
  431.  
  432.  
  433.  
  434.  
  435.  
  436.     while (labels.available() > 0 && numLabelsRead < numLabels) {
  437.       byte label = labels.readByte();
  438.       numLabelsRead++;
  439.       int[][] image = new int[numCols][numRows];
  440.  
  441.  
  442.       for (int colIdx = 0; colIdx < numCols; colIdx++) {
  443.         for (int rowIdx = 0; rowIdx < numRows; rowIdx++) {
  444.           image[colIdx][rowIdx] = images.readUnsignedByte();
  445.         }
  446.       }
  447.       numImagesRead++;
  448.  
  449.       int m = 0;
  450.       for(int i=0; i<28; i++)
  451.       {
  452.         for(int j=0; j<28; j++)
  453.         {
  454.                 x[numImagesRead-1][m] = image[i][j];
  455.                 m++;
  456.         }
  457.       }
  458.  
  459.       for(int i=0; i<10; i++)
  460.       {
  461.         t[numLabelsRead-1][i] = 0;
  462.       }
  463.  
  464.       t[numImagesRead-1][label] = 1;
  465.     }
  466.  
  467.  
  468.     int success = 0;
  469.  
  470.     int cntLabels[] = new int[10];
  471.  
  472.     for(int i=0; i<10000; i++)
  473.     {
  474.         int outputLabel = sk.neuralnetwork.feedForward(x[i]);
  475.         if(t[i][outputLabel] == 1)
  476.         {
  477.             success++;
  478.         }
  479.         cntLabels[outputLabel]++;
  480.     }
  481.  
  482.  
  483.  
  484.  
  485.     System.out.println("Total success = " + success);
  486.  
  487.     for(int i=0; i<10; i++)
  488.     {
  489.         System.out.println("numLabels["+ i + "] = "+cntLabels[i]);
  490.     }
  491.  
  492.     System.out.println();
  493.     long end = System.currentTimeMillis();
  494.     long elapsed = end - start;
  495.     long minutes = elapsed / (1000 * 60);
  496.     long seconds = (elapsed / 1000) - (minutes * 60);
  497.     System.out.println("Read " + numLabelsRead + " samples in " + minutes + " m " + seconds + " s ");
  498.   }
  499.  
  500. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement