Advertisement
Not a member of Pastebin yet?
Sign up —
it unlocks many cool features!
- package NN;
- import java.util.Random;
/**
 * A small fully connected feed-forward neural network with a single hidden layer, trained by
 * per-sample backpropagation.
 *
 * <p>Both layers use the same activation function: the hyperbolic tangent when the network is
 * built with {@code lowerRange == -1}, otherwise the logistic sigmoid. When bias is enabled, each
 * weight matrix carries one extra trailing row holding the bias weights of the next layer.
 *
 * <p>Instances keep mutable per-pass buffers with no synchronization.
 */
public class NN {
    private final int inputCount, hiddenCount, outputCount;
    private final boolean bias;
    /** True selects tanh (range -1..1); false selects the logistic sigmoid (range 0..1). */
    private final boolean useTanh;
    private double learningRate;
    /** Input-to-hidden weights, indexed [input][hidden]; with bias, the last row holds biases. */
    private final double[][] weight1;
    /** Hidden-to-output weights, indexed [hidden][output]; with bias, the last row holds biases. */
    private final double[][] weight2;
    private final double[] hiddenError, outputError, hidden, output;

    /**
     * Creates a network with randomly initialised weights.
     *
     * @param inputCount   number of input nodes
     * @param hiddenCount  number of hidden nodes
     * @param outputCount  number of output nodes
     * @param learningRate step size applied to every weight update
     * @param bias         whether every hidden and output node gets a trainable bias weight
     * @param lowerRange   {@code -1} selects tanh activations; any other value selects the sigmoid
     */
    public NN(int inputCount, int hiddenCount, int outputCount, double learningRate, boolean bias, int lowerRange) {
        this.inputCount = inputCount;
        this.hiddenCount = hiddenCount;
        this.outputCount = outputCount;
        this.learningRate = learningRate;
        this.bias = bias;
        this.useTanh = lowerRange == -1;
        // One extra row per weight matrix stores the bias weights of the following layer.
        int biasRows = bias ? 1 : 0;
        weight1 = new double[inputCount + biasRows][hiddenCount];
        weight2 = new double[hiddenCount + biasRows][outputCount];
        hiddenError = new double[hiddenCount];
        outputError = new double[outputCount];
        hidden = new double[hiddenCount];
        output = new double[outputCount];
        randomise();
    }

    /**
     * Propagates {@code inputs} through the network, storing hidden and output activations.
     * The result can be read with {@link #getOutput()}.
     *
     * @param inputs one value per input node
     * @param print  currently ignored; kept for API compatibility
     * @throws NullPointerException     if {@code inputs} is null
     * @throws IllegalArgumentException if {@code inputs.length} differs from the input count
     */
    public void forwardPass(double[] inputs, boolean print) {
        requireLength(inputs, inputCount, "inputs");
        feedLayer(inputs, inputCount, weight1, hidden);
        feedLayer(hidden, hiddenCount, weight2, output);
    }

    /**
     * Runs one forward pass on {@code input} and then adjusts all weights by one gradient step
     * towards {@code desiredOutput}.
     *
     * @param input         one value per input node
     * @param desiredOutput target value per output node
     * @throws NullPointerException     if either array is null
     * @throws IllegalArgumentException if either array has the wrong length
     */
    public void backpropagation(double[] input, double[] desiredOutput) {
        // Validate before touching any state so a bad call leaves the network unchanged.
        requireLength(desiredOutput, outputCount, "desiredOutput");
        forwardPass(input, false);

        for (int o = 0; o < outputCount; o++) {
            outputError[o] = (desiredOutput[o] - output[o]) * activationFunctionDerivative(output[o]);
        }
        // Hidden deltas must be computed from weight2 BEFORE it is updated. The bias row is
        // excluded because the bias node has no incoming connection to propagate to.
        for (int h = 0; h < hiddenCount; h++) {
            double propagated = 0;
            for (int o = 0; o < outputCount; o++) {
                propagated += outputError[o] * weight2[h][o];
            }
            hiddenError[h] = activationFunctionDerivative(hidden[h]) * propagated;
        }
        adjustLayer(hidden, hiddenCount, weight2, outputError);
        adjustLayer(input, inputCount, weight1, hiddenError);
    }

    /** Computes {@code out[j] = f(sum_i in[i] * weights[i][j] (+ bias weight))} for one layer. */
    private void feedLayer(double[] in, int inCount, double[][] weights, double[] out) {
        for (int j = 0; j < out.length; j++) {
            double sum = 0;
            for (int i = 0; i < inCount; i++) {
                sum += in[i] * weights[i][j];
            }
            if (bias) {
                sum += weights[inCount][j];
            }
            out[j] = activationFunction(sum);
        }
    }

    /** Applies the delta rule {@code w[i][j] += rate * err[j] * in[i]}, plus the bias weights. */
    private void adjustLayer(double[] in, int inCount, double[][] weights, double[] err) {
        for (int j = 0; j < err.length; j++) {
            for (int i = 0; i < inCount; i++) {
                weights[i][j] += learningRate * err[j] * in[i];
            }
            if (bias) {
                weights[inCount][j] += learningRate * err[j];
            }
        }
    }

    /** Initialises all weights uniformly in [-0.1, 0.1), then sets every bias weight to 1. */
    private void randomise() {
        Random rand = new Random();
        fillRandom(weight1, rand);
        fillRandom(weight2, rand);
        if (bias) {
            for (int h = 0; h < hiddenCount; h++) {
                weight1[inputCount][h] = 1;
            }
            for (int o = 0; o < outputCount; o++) {
                weight2[hiddenCount][o] = 1;
            }
        }
    }

    /** Fills {@code matrix} with values drawn uniformly from [-0.1, 0.1). */
    private static void fillRandom(double[][] matrix, Random rand) {
        for (double[] row : matrix) {
            for (int c = 0; c < row.length; c++) {
                row[c] = (2 * rand.nextDouble() - 1) / 10;
            }
        }
    }

    /** Throws if {@code values} is null or does not have exactly {@code expected} entries. */
    private static void requireLength(double[] values, int expected, String name) {
        if (values == null) {
            throw new NullPointerException(name + " must not be null");
        }
        if (values.length != expected) {
            throw new IllegalArgumentException(
                    name + " has length " + values.length + ", expected " + expected);
        }
    }

    public void setLearningRate(double rate) {
        this.learningRate = rate;
    }

    public double getLearningRate() {
        return learningRate;
    }

    /**
     * Returns the output activations of the most recent forward pass.
     *
     * <p>This is the network's internal buffer, not a copy: it is overwritten by the next
     * {@link #forwardPass} or {@link #backpropagation} call.
     */
    public double[] getOutput() {
        return output;
    }

    /** Applies the configured activation function to a pre-activation sum {@code x}. */
    private double activationFunction(double x) {
        if (useTanh) {
            // Algebraically equal to tanh(x).
            return (2.0 / (1.0 + Math.exp(-2 * x))) - 1;
        }
        return 1.0 / (1.0 + Math.exp(-x));
    }

    /**
     * Returns the activation function's slope, expressed in terms of its OUTPUT.
     *
     * @param y an already-activated value (a stored hidden or output activation)
     */
    private double activationFunctionDerivative(double y) {
        // Callers pass activations y = f(x), so both derivatives are written in terms of y:
        // tanh'(x) = 1 - y^2 and sigmoid'(x) = y * (1 - y). Re-applying f to y here would give
        // an incorrect gradient.
        if (useTanh) {
            return 1 - y * y;
        }
        return y * (1 - y);
    }
}
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement