Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- package dl;
- import org.apache.commons.math3.linear.MatrixUtils;
- import org.apache.commons.math3.linear.RealMatrix;
- import java.util.List;
- import java.util.concurrent.ThreadLocalRandom;
- public class NeuralNetwork {
- private int input_nodes_n;
- private int hidden_nodes_n;
- private int output_nodes_n;
- private double learning_rate;
- private RealMatrix input_hiddenMatrix;
- private RealMatrix hidden_outputMatrix;
- private RealMatrix final_inputs;
- public NeuralNetwork(int input_nodes_n, int hidden_nodes_n, int output_nodes_n, double learning_rate){
- this.input_nodes_n = input_nodes_n;
- this.hidden_nodes_n = hidden_nodes_n;
- this.output_nodes_n = output_nodes_n;
- this.learning_rate = learning_rate;
- input_hiddenMatrix = generateMatrixWithRandomWeights (hidden_nodes_n, input_nodes_n);
- hidden_outputMatrix = generateMatrixWithRandomWeights (output_nodes_n, hidden_nodes_n);
- }
- public void trainNeuralNetwork(List inputs_list, List targets_list){
- RealMatrix input_data_vector = listToVector(inputs_list).transpose();
- RealMatrix target_data_vector = listToVector(targets_list).transpose();
- RealMatrix hidden_inputs = input_hiddenMatrix.multiply(input_data_vector); // multiply : vector product
- RealMatrix hidden_outputs = activationFunction(hidden_inputs);
- RealMatrix final_outputs = activationFunction(final_inputs);
- RealMatrix output_errors = target_data_vector.subtract(final_outputs);
- RealMatrix hidden_errors = hidden_outputMatrix.transpose().multiply(output_errors);
- RealMatrix auxM = calculteDescentGradient_ErrorByWeights (output_errors,final_outputs);
- RealMatrix aux2 = auxM.multiply(hidden_outputs.transpose());
- aux2 = aux2.scalarMultiply(learning_rate); // scalar multiply : mulitply each matrix element by a scalar
- hidden_outputMatrix = hidden_outputMatrix.add(aux2);
- auxM = calculteDescentGradient_ErrorByWeights (hidden_errors,hidden_outputs);
- aux2 = auxM.multiply(input_data_vector.transpose());
- aux2 = aux2.scalarMultiply(learning_rate);
- input_hiddenMatrix = input_hiddenMatrix.add(aux2);
- }
- public static void printMatrix(RealMatrix matrix){
- int rows = matrix.getRowDimension(), columns = matrix.getColumnDimension();
- System.out.println("\n\n\n START HERE \n");
- for(int row = 0; row < rows; row++){
- System.out.println("\n");
- for(int column = 0; column < columns; column++){
- System.out.print(matrix.getEntry(row, column) + " | ");
- }
- }
- }
- public RealMatrix applyNeuralNetwork(List inputs_list){
- RealMatrix input_data_matrix = listToVector(inputs_list).transpose();
- RealMatrix hidden_inputs = input_hiddenMatrix.multiply(input_data_matrix);
- RealMatrix hidden_outputs = activationFunction(hidden_inputs);
- RealMatrix final_inputs = hidden_outputMatrix.multiply(hidden_outputs);
- RealMatrix final_outputs = activationFunction(final_inputs);
- return final_outputs;
- }
- private RealMatrix calculteDescentGradient_ErrorByWeights(RealMatrix A, RealMatrix B){
- int i, j, rows = A.getRowDimension(), columns = A.getColumnDimension();
- RealMatrix result = MatrixUtils.createRealMatrix(rows, columns);
- for(i = 0; i < rows; i++){
- for(j = 0; j < columns; j++){
- double a = A.getEntry(i,j), b = B.getEntry(i,j);
- result.setEntry(i, j, a * b * (1.0 - b));
- }
- }
- return result;
- }
- private RealMatrix activationFunction(RealMatrix matriz){
- int i, j, rows = matriz.getRowDimension(), columns = matriz.getColumnDimension();
- RealMatrix result = MatrixUtils.createRealMatrix(rows, columns);
- for(i = 0; i < rows; i++){
- for(j = 0; j < columns; j++){
- result.setEntry(i, j, sigmoidFunction(matriz.getEntry(i,j)));
- }
- }
- return result;
- }
- private double sigmoidFunction(double x) {
- return (1.0 / (1.0 + Math.exp(-x)));
- }
- private RealMatrix listToVector(List list){
- int list_size = list.size();
- RealMatrix vector = MatrixUtils.createRealMatrix(1, list_size);
- for(int column = 0; column < list_size; column++){
- vector.setEntry(0, column, (double)list.get(column));
- }
- return vector;
- }
- private RealMatrix generateMatrixWithRandomWeights (int rows, int columns){
- RealMatrix random_matrix = MatrixUtils.createRealMatrix(rows, columns);
- for(int row = 0; row < rows; row++){
- for(int column = 0; column < columns; column++){
- random_matrix.setEntry (row, column, ThreadLocalRandom.current() .nextDouble(-0.5,0.5));
- }
- }
- return random_matrix;
- }
- }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement