Advertisement
Guest User

Untitled

a guest
Feb 21st, 2019
74
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 4.35 KB | None | 0 0
  #include <cmath>
  #include <cstdio>
  #include <iostream>
  #include "mat.h"
  3.  
  4. #define SLOPE 1
  5. #define ETA 0.1
  6.  
  7. using namespace std;
  8.  
  9. double transferFunction(double d);
  10. double transferFunctionDerrivative(double d);
  11. double stepFunction(double d);
  12. Matrix addBias(Matrix m);
  13.  
  14. class NeuronLayer {
  15. public:
  16. NeuronLayer(int numNeurons, int numInputs) {
  17. matrix = new Matrix(numInputs, numNeurons);
  18. matrix->rand(0.0, 1.0);
  19. }
  20.  
  21. Matrix *matrix;
  22. };
  23.  
  24. class NeuralNetwork {
  25. public:
  26. NeuronLayer *l1;
  27. NeuronLayer *l2;
  28.  
  29. void train(Matrix input, Matrix target) {
  30. bool running = true;
  31. int count = 0;
  32. double sum = 0;
  33. while (running) {
  34. Matrix l1Out = input.dot(l1->matrix);
  35. l1Out.map(transferFunction);
  36. l1Out = addBias(l1Out);
  37. Matrix l2Out = l1Out.dot(l2->matrix);
  38. l2Out.map(transferFunction);
  39.  
  40. Matrix l2Error = new Matrix(target);
  41. l2Error.sub(l2Out);
  42. Matrix l2D = new Matrix(l2Out);
  43. l2D.map(transferFunctionDerrivative);
  44. Matrix l2Delta = l2Error.mul(l2D);
  45.  
  46. Matrix l1Error = l2Delta.dotT(l2->matrix);
  47. Matrix l1D = new Matrix(l1Out);
  48. l1D.map(transferFunctionDerrivative);
  49. Matrix l1Delta = l1Error.mul(l1D);
  50.  
  51. Matrix l1Adjustment = input.Tdot(l1Delta);
  52. l1Adjustment = l1Adjustment.extract(0, 0, l1Adjustment.numRows(), l1Adjustment.numCols() - 1);
  53. Matrix l2Adjustment = l1Out.Tdot(l2Delta);
  54. l1Adjustment.scalarMul(ETA);
  55. l2Adjustment.scalarMul(ETA);
  56. l1->matrix->add(l1Adjustment);
  57. l2->matrix->add(l2Adjustment);
  58. // sum error
  59. sum = 0;
  60. Matrix error = new Matrix(target);
  61. error.sub(l2Out);
  62. for (int r = 0; r < error.numRows(); r++) {
  63. for (int c = 0; c < error.numCols(); c++) {
  64. sum += error.get(r,c) * error.get(r, c);
  65. }
  66. }
  67. running = sum > 0.01 && count++ < 100000;
  68. }
  69.  
  70. }
  71.  
  72. Matrix think(Matrix input) {
  73. Matrix l1Out = input.dot(l1->matrix);
  74. l1Out.map(transferFunction);
  75. l1Out = addBias(l1Out);
  76. Matrix l2Out = l1Out.dot(l2->matrix);
  77. l2Out.map(transferFunction);
  78. return l2Out;
  79. }
  80. };
  81.  
  82. int main() {
  83. initRand();
  84. int numInputs, numHiddenNodes, numClasses, numOutputs;
  85. scanf("%d %d %d", &numInputs, &numHiddenNodes, &numClasses);
  86. Matrix raw;
  87. raw.read();
  88. numOutputs = raw.numCols() - numInputs;
  89.  
  90. Matrix input = raw.extract(0,0,raw.numRows(), numInputs);
  91. input.setName("input");
  92. input = addBias(input);
  93.  
  94. Matrix target = raw.extract(0, numInputs, 0, 0);
  95. target.setName("target");
  96.  
  97. auto nn = new NeuralNetwork();
  98. nn->l1 = new NeuronLayer(numHiddenNodes, numInputs + 1);
  99. nn->l2 = new NeuronLayer(numOutputs, numHiddenNodes + 1);
  100. nn->train(input, target);
  101.  
  102. raw.read();
  103. Matrix testInput = raw.extract(0, 0, raw.numRows(), numInputs);
  104. testInput = addBias(testInput);
  105. Matrix testTarget = raw.extract(0, numInputs, 0, 0);
  106.  
  107. Matrix testOutput = nn->think(testInput);
  108. testOutput.map(stepFunction);
  109. target.printfmt("Target", "%.4f ", false);
  110. testOutput.printfmt("Predicted", "%.4f ", false);
  111. Matrix confusionMatrix = new Matrix(2, 2);
  112. confusionMatrix.map([](double d) -> double {return 0;});
  113. for (int r = 0; r < testTarget.numRows(); r++) {
  114. for (int c = 0; c < testTarget.numCols(); c++) {
  115. double actual = testTarget.get(r, c);
  116. double predicted = testOutput.get(r, c);
  117. int row = (int) actual == 1 ? 1 : 0;
  118. int col = (int) predicted == 1 ? 1 : 0;
  119. confusionMatrix.inc(row, col);
  120. }
  121. }
  122. confusionMatrix.printfmt("Confusion Matrix", "%.4f ", false);
  123. }
  124.  
  125. double transferFunction(double d) {
  126. return 1.0/(1.0 + exp(-SLOPE*d));
  127. }
  128.  
  129. double transferFunctionDerrivative(double d) {
  130. return SLOPE * d * (1 - d);
  131. }
  132.  
  /// Hard threshold at 0.5: returns 1.0 when d >= 0.5, otherwise 0.0.
  /// NaN compares false, so it maps to 0.0.
  double stepFunction(double d) {
      const bool fires = (d >= 0.5);
      return static_cast<double>(fires);
  }
  136.  
  137. Matrix addBias(Matrix m) {
  138. Matrix matrix = new Matrix(m.numRows(), m.numCols() + 1);
  139. matrix.setName(m.getName());
  140. for (int r = 0; r < m.numRows(); r++) {
  141. for (int c = 0; c < m.numCols(); c++) {
  142. matrix.set(r, c, m.get(r, c));
  143. }
  144. }
  145. matrix.mapCol(m.numCols(), [](double d) -> double {return 1.0;});
  146. return matrix;
  147. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement