Advertisement
Guest User

Untitled

a guest
Apr 30th, 2016
68
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.28 KB | None | 0 0
  1. package NN;
  2.  
  3. import java.util.Random;
  4.  
  /**
   * A small fully-connected feed-forward network with one hidden layer, trained
   * by per-sample backpropagation.
   *
   * <p>The activation is selected by the {@code lowerRange} constructor argument:
   * {@code -1} selects a tanh-shaped function with outputs in (-1, 1); any other
   * value selects the logistic sigmoid with outputs in (0, 1).
   *
   * <p>When bias is enabled, each weight matrix has one extra row (the last one)
   * holding the bias weights, which are applied with an implicit input of 1.
   */
  public class NN {
      private final int inputCount, hiddenCount, outputCount;
      private double learningRate;
      private final double lowerRange;
      /** weight1[i][h]: input i -> hidden h; weight2[h][o]: hidden h -> output o. */
      private final double[][] weight1, weight2;
      private final double[] hiddenError, outputError, hidden, output;
      /** The input vector passed to the most recent forward pass (not copied). */
      private double[] input;
      private final boolean bias;

      /**
       * Creates a network and initialises its weights.
       *
       * @param inputCount   number of input values
       * @param hiddenCount  number of hidden units
       * @param outputCount  number of output units
       * @param learningRate step size used by {@link #backpropagation}
       * @param bias         whether to add a bias weight row to both layers
       * @param lowerRange   {@code -1} for the tanh-shaped activation, anything
       *                     else for the logistic sigmoid
       */
      public NN(int inputCount, int hiddenCount, int outputCount, double learningRate, boolean bias, int lowerRange) {
          this.bias = bias;
          this.learningRate = learningRate;
          this.inputCount = inputCount;
          this.hiddenCount = hiddenCount;
          this.outputCount = outputCount;
          this.lowerRange = lowerRange;

          // One extra row per matrix holds the bias weights when bias is enabled.
          int biasRows = bias ? 1 : 0;
          weight1 = new double[inputCount + biasRows][hiddenCount];
          weight2 = new double[hiddenCount + biasRows][outputCount];

          hiddenError = new double[hiddenCount];
          outputError = new double[outputCount];
          hidden = new double[hiddenCount];
          output = new double[outputCount];
          randomise();
      }

      /**
       * Propagates {@code inputs} through the network and stores the activations;
       * read the result with {@link #getOutput()}.
       *
       * @param inputs input vector; must contain exactly {@code inputCount} values
       * @param print  currently unused; kept for interface compatibility
       * @throws IllegalArgumentException if {@code inputs} has the wrong length
       */
      public void forwardPass(double[] inputs, boolean print) {
          // Reject mismatched lengths up front: with bias enabled, an over-long
          // vector would otherwise be multiplied against the bias weight row.
          if (inputs.length != inputCount) {
              throw new IllegalArgumentException(
                      "expected " + inputCount + " inputs but got " + inputs.length);
          }
          this.input = inputs;
          feedLayer(inputs, weight1, hidden);
          feedLayer(hidden, weight2, output);
      }

      /** Computes one layer: target[t] = activation(sum_s source[s] * weights[s][t] (+ bias weight)). */
      private void feedLayer(double[] source, double[][] weights, double[] target) {
          for (int t = 0; t < target.length; t++) {
              double sum = 0;
              for (int s = 0; s < source.length; s++) {
                  sum += source[s] * weights[s][t];
              }
              if (bias) {
                  // The bias row sits directly after the last regular row.
                  sum += weights[source.length][t];
              }
              target[t] = activationFunction(sum);
          }
      }

      /**
       * Runs a forward pass on {@code input} and then performs one gradient step
       * towards {@code desiredOutput}, updating both weight matrices in place.
       *
       * @param input         input vector; must contain exactly {@code inputCount} values
       * @param desiredOutput target values; the first {@code outputCount} entries are used
       * @throws IllegalArgumentException if {@code input} has the wrong length or
       *                                  {@code desiredOutput} is too short
       */
      public void backpropagation(double[] input, double[] desiredOutput) {
          // Validate before the forward pass so a bad call leaves the network untouched.
          if (desiredOutput.length < outputCount) {
              throw new IllegalArgumentException(
                      "expected at least " + outputCount + " desired outputs but got " + desiredOutput.length);
          }
          forwardPass(input, false);

          // Output-layer deltas.
          for (int o = 0; o < outputCount; o++) {
              outputError[o] = (desiredOutput[o] - output[o]) * activationFunctionDerivitive(output[o]);
          }

          // Hidden-layer deltas; these must use the hidden->output weights from
          // BEFORE the update below, so they are computed first.
          for (int h = 0; h < hiddenCount; h++) {
              double sum = 0;
              for (int o = 0; o < outputCount; o++) {
                  sum += outputError[o] * weight2[h][o];
              }
              hiddenError[h] = activationFunctionDerivitive(hidden[h]) * sum;
          }

          // Update the hidden -> output weights.
          for (int o = 0; o < outputCount; o++) {
              for (int h = 0; h < hiddenCount; h++) {
                  weight2[h][o] += learningRate * outputError[o] * hidden[h];
              }
              if (bias) {
                  weight2[hiddenCount][o] += learningRate * outputError[o];
              }
          }

          // Update the input -> hidden weights.
          for (int h = 0; h < hiddenCount; h++) {
              for (int i = 0; i < inputCount; i++) {
                  weight1[i][h] += learningRate * hiddenError[h] * input[i];
              }
              if (bias) {
                  weight1[inputCount][h] += learningRate * hiddenError[h];
              }
          }
      }

      /**
       * Fills both weight matrices with uniform random values in [-0.1, 0.1) and,
       * when bias is enabled, sets every bias weight to 1.
       */
      private void randomise() {
          Random rand = new Random();
          fillRandom(weight1, rand);
          fillRandom(weight2, rand);
          if (bias) {
              for (int h = 0; h < hiddenCount; h++) {
                  weight1[inputCount][h] = 1;
              }
              for (int o = 0; o < outputCount; o++) {
                  weight2[hiddenCount][o] = 1;
              }
          }
      }

      /** Overwrites every cell of {@code weights} with a uniform value in [-0.1, 0.1). */
      private static void fillRandom(double[][] weights, Random rand) {
          for (double[] row : weights) {
              for (int c = 0; c < row.length; c++) {
                  row[c] = (2 * rand.nextDouble() - 1) / 10;
              }
          }
      }

      /** Sets the step size used by subsequent calls to {@link #backpropagation}. */
      public void setLearningRate(double rate) {
          this.learningRate = rate;
      }

      /** Returns the current learning rate. */
      public double getLearningRate() {
          return learningRate;
      }

      /**
       * Returns the output activations of the most recent forward pass. This is
       * the network's internal array, not a copy: it is overwritten by the next
       * forward pass.
       */
      public double[] getOutput() {
          return output;
      }

      /** Applies the activation selected at construction to the weighted sum {@code d}. */
      private double activationFunction(double d) {
          if (lowerRange == -1) {
              return (2.0 / (1.0 + Math.exp(-2 * d))) - 1;
          }
          return 1.0 / (1.0 + Math.exp(-d));
      }

      /**
       * Returns the derivative of the activation, expressed in terms of the
       * activation's OUTPUT value.
       *
       * @param d an already-activated value (callers pass {@code output[..]} or
       *          {@code hidden[..]}), not the raw weighted sum
       */
      private double activationFunctionDerivitive(double d) {
          if (lowerRange == -1) {
              // y = tanh(x)  =>  dy/dx = 1 - y^2. d is already y, so the
              // activation must not be applied a second time here.
              return 1 - d * d;
          }
          // y = sigmoid(x)  =>  dy/dx = y * (1 - y).
          return d * (1 - d);
      }
  }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement