Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- /*
- import java.io.DataInputStream;
- import java.io.FileInputStream;
- import java.io.IOException;
- */
- import java.io.*;
- import java.lang.Math;
/**
 * A single neuron: the weights on its incoming edges plus the scratch values that the
 * network writes during a forward pass ({@code z}, {@code y}) and a backward pass
 * ({@code delta}, {@code bigDelta}).
 */
class Neuron implements Serializable {
    /** Weight of each incoming edge, indexed by the source neuron in the previous layer. */
    double[] weightVector;
    /** Number of incoming edges (the length of both arrays). */
    int numInEdges;
    /** Output of the neuron. */
    double y;
    /** Weighted sum of the inputs. */
    double z;
    /** Back-propagated error term. */
    double delta;
    /** Per-edge gradient accumulator, one slot per incoming edge. */
    double[] bigDelta;

    /**
     * Creates a neuron with {@code inputEdges} incoming edges. Each weight is drawn
     * uniformly from [-0.5, 0.5) and every gradient accumulator starts at zero.
     *
     * @param inputEdges number of incoming edges (0 for an input neuron)
     */
    public Neuron(int inputEdges) {
        numInEdges = inputEdges;
        bigDelta = new double[inputEdges];
        weightVector = new double[inputEdges];
        int k = 0;
        while (k < weightVector.length) {
            weightVector[k++] = Math.random() - 0.5;
        }
    }
}
- class NeuronLayer implements Serializable{
- int numNeurons;
- Neuron neurons[];
- public NeuronLayer(int neuronInThisLayer)
- {
- numNeurons = neuronInThisLayer + 1;
- neurons = new Neuron[numNeurons];
- }
- }
- class NeuralNetwork implements Serializable{
- int numInputUnit;
- int numOutputUnit;
- int numHiddenLayer;
- int numNeuronPerHiddenLayer;
- int numTotalLayer;
- NeuronLayer neuronlayers[];
- int L; //id of last layer
- double learningRate, prevError, currentError;
- int numTrained;
- public NeuralNetwork(int _inputUnit, int _outputUnit, int _hiddenLayer, int _neuronPerHiddenLayer)
- {
- learningRate = 1.0;
- prevError = Double.POSITIVE_INFINITY;
- numTrained = 0;
- numInputUnit = _inputUnit;
- numOutputUnit = _outputUnit;
- numHiddenLayer = _hiddenLayer;
- numNeuronPerHiddenLayer = _neuronPerHiddenLayer;
- numTotalLayer = 2 + numHiddenLayer;
- L = numTotalLayer - 1;
- neuronlayers = new NeuronLayer[numTotalLayer];
- neuronlayers[0] = new NeuronLayer(numInputUnit);
- neuronlayers[L] = new NeuronLayer(numOutputUnit);
- for(int i=1; i < L; i++)
- {
- neuronlayers[i] = new NeuronLayer(numNeuronPerHiddenLayer);
- }
- for(int layer=0; layer<=L; layer++)
- {
- for(int i=0; i<neuronlayers[layer].numNeurons; i++)
- {
- if(layer == 0)
- {
- neuronlayers[layer].neurons[i] = new Neuron(0);
- }
- else
- {
- neuronlayers[layer].neurons[i] = new Neuron( neuronlayers[layer-1].numNeurons );
- }
- }
- neuronlayers[layer].neurons[ neuronlayers[layer].numNeurons - 1].y = 1; //bias node
- }
- }
- public int feedForward(double x[])
- {
- for(int i=0; i<x.length; i++)
- {
- neuronlayers[0].neurons[i].y = (x[i] / 127.5) - 1.0;
- }
- for(int i=1; i<=L; i++)
- {
- feedALevel(neuronlayers[i-1], neuronlayers[i]);
- }
- double maxVal = -1;
- int idOfMaxVal = -1;
- for(int i=0; i<numOutputUnit; i++)
- {
- if(neuronlayers[L].neurons[i].y > maxVal)
- {
- maxVal = neuronlayers[L].neurons[i].y;
- idOfMaxVal = i;
- }
- }
- return idOfMaxVal;
- }
- public void feedALevel(NeuronLayer prevLayer, NeuronLayer curLayer)
- {
- for(int n=0; n < curLayer.numNeurons - 1; n++)
- {
- curLayer.neurons[n].z = 0;
- for(int i=0; i<prevLayer.numNeurons; i++)
- {
- curLayer.neurons[n].z += curLayer.neurons[n].weightVector[i] * prevLayer.neurons[i].y;
- }
- curLayer.neurons[n].y = sigmoid( curLayer.neurons[n].z );
- }
- }
- public double calcError(double x[][], double t[][], int m)
- {
- double error = 0;
- double hx;
- for(int i=0; i<m; i++)
- {
- feedForward(x[i]);
- for(int j = 0; j<10; j++)
- {
- hx = neuronlayers[L].neurons[j].y;
- error += ( -t[i][j]*Math.log(hx) - (1 - t[i][j])*Math.log(1 - hx) );
- }
- }
- error = error / m;
- System.out.println("cost J = "+ error + " at learningRate = " + learningRate);
- return error;
- }
- public void backPropagate(double x[][], double t[][], int m)
- {
- for(int kase = 0; kase<m; kase++)
- {
- feedForward(x[kase]);
- for(int i=0; i<numOutputUnit; i++)
- {
- neuronlayers[L].neurons[i].delta = neuronlayers[L].neurons[i].y - t[kase][i];
- }
- for(int i=L-1; i>0; i--)
- {
- calcDelta(neuronlayers[i], neuronlayers[i+1]);
- }
- for(int i=L; i>0; i--)
- {
- for(int j=0; j<neuronlayers[i].numNeurons-1; j++)
- {
- for(int k=0; k<neuronlayers[i-1].numNeurons; k++)
- {
- neuronlayers[i].neurons[j].bigDelta[k] += neuronlayers[i-1].neurons[k].y * neuronlayers[i].neurons[j].delta;
- }
- }
- }
- }
- for(int i=1; i<=L; i++)
- {
- for(int j=0; j<neuronlayers[i].numNeurons-1; j++)
- {
- for(int k=0; k<neuronlayers[i].neurons[j].numInEdges; k++)
- {
- neuronlayers[i].neurons[j].weightVector[k] -= (learningRate * neuronlayers[i].neurons[j].bigDelta[k] /*+ 0.01*neuronlayers[i].neurons[j].weightVector[k] */ )/(double)m;
- neuronlayers[i].neurons[j].bigDelta[k] = 0;
- }
- }
- }
- /************************/
- /*
- * error calculation and fixing learning rate depending on previous error
- * and current error.
- */
- currentError = calcError(x, t, m);
- if(currentError < prevError )
- {
- learningRate = learningRate * 1.04;
- }
- else
- {
- learningRate = learningRate * 0.7;
- }
- prevError = currentError;
- numTrained++;
- }
- public void calcDelta(NeuronLayer curLayer, NeuronLayer forwardLayer)
- {
- for(int i=0; i < curLayer.numNeurons - 1; i++)
- {
- double delta = 0;
- for(int j=0; j<forwardLayer.numNeurons-1; j++)
- {
- delta += forwardLayer.neurons[j].weightVector[i] * forwardLayer.neurons[j].delta;
- }
- delta = delta * curLayer.neurons[i].y * (1.0 - curLayer.neurons[i].y);
- curLayer.neurons[i].delta = delta;
- }
- }
- public double sigmoid(double z)
- {
- return 1.0/(1.0 + Math.exp(-z));
- }
- }
- class SadaKhata{
- NeuralNetwork neuralnetwork;
- public void saveData()
- {
- try{
- ObjectOutputStream objectoutputstream = new ObjectOutputStream(new FileOutputStream("weights-3.dat"));
- objectoutputstream.writeObject(neuralnetwork);
- }catch(Exception ex)
- {
- System.out.println("<<<<<<<<< COULDN'T WRITE OBJECT >>>>>>>>>>>");
- }
- }
- public void loadData(int _inputUnit, int _outputUnit , int _hiddenLayer , int _neuronPerHiddenLayer )
- {
- try{
- ObjectInputStream objectinputstream = new ObjectInputStream(new FileInputStream("weights-3.dat"));
- try{
- neuralnetwork = (NeuralNetwork) objectinputstream.readObject();
- }catch(Exception ex)
- {
- System.out.println("<<< COULDN'T READ OBJECT >>>");
- neuralnetwork = new NeuralNetwork(_inputUnit, _outputUnit, _hiddenLayer, _neuronPerHiddenLayer);
- }
- }catch(Exception ex)
- {
- System.out.println("<<< COULDN'T READ OBJECT >>>");
- neuralnetwork = new NeuralNetwork(_inputUnit, _outputUnit, _hiddenLayer, _neuronPerHiddenLayer);
- }
- }
- }
- public class MNISTReader {
- /**
- * @param args
- * args[0]: label file; args[1]: data file.
- * @throws IOException
- */
- public static void main(String[] args) throws IOException {
- DataInputStream labels = new DataInputStream(new FileInputStream("train-labels"));
- DataInputStream images = new DataInputStream(new FileInputStream("train-images"));
- int magicNumber = labels.readInt();
- if (magicNumber != 2049) {
- System.err.println("Label file has wrong magic number: " + magicNumber + " (should be 2049)");
- System.exit(0);
- }
- magicNumber = images.readInt();
- if (magicNumber != 2051) {
- System.err.println("Image file has wrong magic number: " + magicNumber + " (should be 2051)");
- System.exit(0);
- }
- int numLabels = labels.readInt();
- int numImages = images.readInt();
- int numRows = images.readInt();
- int numCols = images.readInt();
- if (numLabels != numImages) {
- System.err.println("Image file and label file do not contain the same number of entries.");
- System.err.println(" Label file contains: " + numLabels);
- System.err.println(" Image file contains: " + numImages);
- System.exit(0);
- }
- long start = System.currentTimeMillis();
- int numLabelsRead = 0;
- int numImagesRead = 0;
- /******************************/
- double[][] x = new double[60000][784];
- double[][] t = new double[60000][10];
- int kaseno = 0;
- while (labels.available() > 0 && numLabelsRead < numLabels) {
- byte label = labels.readByte();
- numLabelsRead++;
- int[][] image = new int[numCols][numRows];
- for (int colIdx = 0; colIdx < numCols; colIdx++) {
- for (int rowIdx = 0; rowIdx < numRows; rowIdx++) {
- image[colIdx][rowIdx] = images.readUnsignedByte();
- }
- }
- numImagesRead++;
- int m = 0;
- for(int i=0; i<28; i++)
- {
- for(int j=0; j<28; j++)
- {
- x[kaseno][m] = image[i][j];
- m++;
- }
- }
- for(int i=0; i<10; i++) t[kaseno][i] = 0;
- t[kaseno][ label ] = 1;
- kaseno++;
- }
- SadaKhata sk = new SadaKhata();
- sk.loadData(784, 10, 2, 50);
- for(int loop = 0; loop<10; loop++)
- {
- System.out.print(loop+1 + ". ");
- sk.neuralnetwork.backPropagate(x, t, 60000);
- }
- System.out.println("Total number of backPropagation = " + sk.neuralnetwork.numTrained);
- sk.saveData();
- labels = new DataInputStream(new FileInputStream("test-labels"));
- images = new DataInputStream(new FileInputStream("test-images"));
- magicNumber = labels.readInt();
- if (magicNumber != 2049) {
- System.err.println("Label file has wrong magic number: " + magicNumber + " (should be 2049)");
- System.exit(0);
- }
- magicNumber = images.readInt();
- if (magicNumber != 2051) {
- System.err.println("Image file has wrong magic number: " + magicNumber + " (should be 2051)");
- System.exit(0);
- }
- numLabels = labels.readInt();
- numImages = images.readInt();
- numRows = images.readInt();
- numCols = images.readInt();
- if (numLabels != numImages) {
- System.err.println("Image file and label file do not contain the same number of entries.");
- System.err.println(" Label file contains: " + numLabels);
- System.err.println(" Image file contains: " + numImages);
- System.exit(0);
- }
- //long start = System.currentTimeMillis();
- numLabelsRead = 0;
- numImagesRead = 0;
- while (labels.available() > 0 && numLabelsRead < numLabels) {
- byte label = labels.readByte();
- numLabelsRead++;
- int[][] image = new int[numCols][numRows];
- for (int colIdx = 0; colIdx < numCols; colIdx++) {
- for (int rowIdx = 0; rowIdx < numRows; rowIdx++) {
- image[colIdx][rowIdx] = images.readUnsignedByte();
- }
- }
- numImagesRead++;
- int m = 0;
- for(int i=0; i<28; i++)
- {
- for(int j=0; j<28; j++)
- {
- x[numImagesRead-1][m] = image[i][j];
- m++;
- }
- }
- for(int i=0; i<10; i++)
- {
- t[numLabelsRead-1][i] = 0;
- }
- t[numImagesRead-1][label] = 1;
- }
- int success = 0;
- int cntLabels[] = new int[10];
- for(int i=0; i<10000; i++)
- {
- int outputLabel = sk.neuralnetwork.feedForward(x[i]);
- if(t[i][outputLabel] == 1)
- {
- success++;
- }
- cntLabels[outputLabel]++;
- }
- System.out.println("Total success = " + success);
- for(int i=0; i<10; i++)
- {
- System.out.println("numLabels["+ i + "] = "+cntLabels[i]);
- }
- System.out.println();
- long end = System.currentTimeMillis();
- long elapsed = end - start;
- long minutes = elapsed / (1000 * 60);
- long seconds = (elapsed / 1000) - (minutes * 60);
- System.out.println("Read " + numLabelsRead + " samples in " + minutes + " m " + seconds + " s ");
- }
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement