Advertisement
Guest User

NeuralNetwork

a guest
Aug 21st, 2019
82
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Java 4.92 KB | None | 0 0
  1. package dl;
  2.  
  3.  
  4. import org.apache.commons.math3.linear.MatrixUtils;
  5. import org.apache.commons.math3.linear.RealMatrix;
  6.  
  7. import java.util.List;
  8. import java.util.concurrent.ThreadLocalRandom;
  9.  
  10.  
  11.  
  12. public class NeuralNetwork {
  13.  
  14.     private int input_nodes_n;
  15.     private int hidden_nodes_n;
  16.     private int output_nodes_n;
  17.     private double learning_rate;
  18.     private RealMatrix input_hiddenMatrix;
  19.     private RealMatrix hidden_outputMatrix;
  20.     private RealMatrix final_inputs;
  21.  
  22.     public NeuralNetwork(int input_nodes_n, int hidden_nodes_n, int output_nodes_n, double learning_rate){
  23.         this.input_nodes_n = input_nodes_n;
  24.         this.hidden_nodes_n = hidden_nodes_n;
  25.         this.output_nodes_n = output_nodes_n;
  26.         this.learning_rate = learning_rate;
  27.         input_hiddenMatrix = generateMatrixWithRandomWeights (hidden_nodes_n, input_nodes_n);
  28.         hidden_outputMatrix = generateMatrixWithRandomWeights (output_nodes_n, hidden_nodes_n);
  29.     }
  30.  
  31.     public void trainNeuralNetwork(List inputs_list, List targets_list){
  32.         RealMatrix input_data_vector = listToVector(inputs_list).transpose();
  33.         RealMatrix target_data_vector = listToVector(targets_list).transpose();
  34.         RealMatrix hidden_inputs = input_hiddenMatrix.multiply(input_data_vector); // multiply : vector product
  35.         RealMatrix hidden_outputs = activationFunction(hidden_inputs);
  36.         RealMatrix final_outputs = activationFunction(final_inputs);
  37.         RealMatrix output_errors = target_data_vector.subtract(final_outputs);
  38.         RealMatrix hidden_errors = hidden_outputMatrix.transpose().multiply(output_errors);
  39.         RealMatrix auxM = calculteDescentGradient_ErrorByWeights (output_errors,final_outputs);
  40.         RealMatrix aux2 = auxM.multiply(hidden_outputs.transpose());
  41.         aux2 = aux2.scalarMultiply(learning_rate); // scalar multiply : mulitply each matrix element by a scalar
  42.         hidden_outputMatrix = hidden_outputMatrix.add(aux2);
  43.         auxM = calculteDescentGradient_ErrorByWeights (hidden_errors,hidden_outputs);
  44.         aux2 = auxM.multiply(input_data_vector.transpose());
  45.         aux2 = aux2.scalarMultiply(learning_rate);
  46.         input_hiddenMatrix =  input_hiddenMatrix.add(aux2);
  47.     }
  48.  
  49.     public static void printMatrix(RealMatrix matrix){
  50.         int rows = matrix.getRowDimension(), columns = matrix.getColumnDimension();
  51.         System.out.println("\n\n\n START HERE \n");
  52.         for(int row = 0; row < rows; row++){
  53.             System.out.println("\n");
  54.             for(int column = 0; column < columns; column++){
  55.                 System.out.print(matrix.getEntry(row, column) + " | ");
  56.             }
  57.         }
  58.     }
  59.  
  60.     public RealMatrix applyNeuralNetwork(List inputs_list){
  61.         RealMatrix input_data_matrix = listToVector(inputs_list).transpose();
  62.         RealMatrix hidden_inputs = input_hiddenMatrix.multiply(input_data_matrix);
  63.         RealMatrix hidden_outputs = activationFunction(hidden_inputs);
  64.         RealMatrix final_inputs = hidden_outputMatrix.multiply(hidden_outputs);
  65.         RealMatrix final_outputs = activationFunction(final_inputs);
  66.         return final_outputs;
  67.     }
  68.  
  69.     private RealMatrix calculteDescentGradient_ErrorByWeights(RealMatrix A, RealMatrix B){
  70.         int i, j, rows = A.getRowDimension(), columns = A.getColumnDimension();
  71.         RealMatrix result = MatrixUtils.createRealMatrix(rows, columns);
  72.         for(i = 0; i < rows; i++){
  73.             for(j = 0; j < columns; j++){
  74.                 double a = A.getEntry(i,j), b = B.getEntry(i,j);
  75.                 result.setEntry(i, j, a * b * (1.0 - b));
  76.             }
  77.         }
  78.         return result;
  79.     }
  80.  
  81.     private RealMatrix activationFunction(RealMatrix matriz){
  82.         int i, j, rows = matriz.getRowDimension(), columns = matriz.getColumnDimension();
  83.         RealMatrix result = MatrixUtils.createRealMatrix(rows, columns);
  84.         for(i = 0; i < rows; i++){
  85.             for(j = 0; j < columns; j++){
  86.                 result.setEntry(i, j, sigmoidFunction(matriz.getEntry(i,j)));
  87.             }
  88.         }
  89.         return result;
  90.     }
  91.  
  92.     private double sigmoidFunction(double x) {
  93.         return (1.0 / (1.0 + Math.exp(-x)));
  94.     }
  95.  
  96.     private RealMatrix listToVector(List list){
  97.         int list_size = list.size();
  98.         RealMatrix vector = MatrixUtils.createRealMatrix(1, list_size);
  99.         for(int column = 0; column < list_size; column++){
  100.             vector.setEntry(0, column, (double)list.get(column));
  101.         }
  102.         return vector;
  103.     }
  104.  
  105.     private RealMatrix generateMatrixWithRandomWeights (int rows, int columns){
  106.         RealMatrix random_matrix = MatrixUtils.createRealMatrix(rows, columns);
  107.         for(int row = 0; row < rows; row++){
  108.             for(int column = 0; column < columns; column++){
  109.                 random_matrix.setEntry (row, column, ThreadLocalRandom.current() .nextDouble(-0.5,0.5));
  110.             }
  111.         }
  112.         return random_matrix;
  113.     }
  114.  
  115. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement