Advertisement
Hexadroid

backw

Sep 13th, 2020 (edited)
1,328
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C++ 2.14 KB | None | 0 0
  1. #pragma hdrstop
  2. #pragma argsused
  3.  
  4. #ifdef _WIN32
  5. #include <tchar.h>
  6. #else
  7.   typedef char _TCHAR;
  8.   #define _tmain main
  9. #endif
  10.  
  11. #include <stdio.h>
  12. #include <math.h>
  13.  
  14.  
  /* Training-loop iteration counter. As a file-scope object it starts at 0. */
  int i;

  /* Step size applied to every weight adjustment. */
  const double learning_rate = 0.1;

  /* Cost measured at node b, and the error-derivative term 2 * (desired - output). */
  double node_b_cost_function, node_b_cost_function_derivative;

  /* Per-node gradient terms that drive the update of node_b_weight / node_a_weight. */
  double node_b_cost_function_gradient, node_a_cost_function_gradient;

  /* Values fed into each node; node b's input is node a's output. */
  double node_a_input, node_b_input;

  /* Trainable weight applied to each node's input before activation. */
  double node_a_weight, node_b_weight;

  /* Activated output of each node. */
  double node_a_output, node_b_output;

  /* Target value that node_b_output is trained towards. */
  double node_b_desired_output;
  22.  
  23. //NETWORK SIMPLE         "node a"  O-----------O  "node b"
  24.  
  25.  
  26. double node_x_function(double x)
  27. {
  28.     return 1 / (1 + exp(-(x)));
  29. }
  30.  
  /*
   * Slope of the logistic activation, written in terms of the activation's
   * OUTPUT: for y = 1 / (1 + e^(-z)) the derivative dy/dz equals y * (1 - y).
   *
   * x: an already-activated value (a node's output), NOT the raw weighted input.
   * Returns x * (1 - x).
   */
  double node_x_function_derivative(double x)
  {
      return x * (1.0 - x);
  }
  35.  
  36. int _tmain(int argc, _TCHAR* argv[])
  37. {
  38.  
  39.     node_a_input = 0.02;
  40.     node_a_weight = 0.32; node_b_weight = 0.15;
  41.     node_b_desired_output = 0.005;
  42.  
  43.   for (i; i < 3000000; i++) {
  44.  
  45.  
  46.                         node_a_output = node_x_function(node_a_input*node_a_weight);
  47.  
  48.                         node_b_input = node_a_output;
  49.                         node_b_output = node_x_function(node_b_input*node_b_weight);
  50.  
  51.                         //needs forward propagation in here, comes later
  52.  
  53.                         //
  54.  
  55.  
  56.                         //backward propagation simple test
  57.                         node_b_cost_function = pow(node_x_function(node_b_desired_output) - node_b_output,2);
  58.                         //node_b_cost_function_derivative = 2*(node_x_function(node_b_desired_output) - node_b_output);
  59.                         node_b_cost_function_derivative = 2*(node_b_desired_output - node_b_output);
  60.  
  61.                         node_b_cost_function_gradient =  node_x_function_derivative(node_b_output) * node_b_cost_function_derivative;
  62.                         node_a_cost_function_gradient =  node_x_function_derivative(node_a_output)  * node_b_weight * node_b_cost_function_gradient;
  63.  
  64.                         node_b_weight = node_b_weight + learning_rate * node_b_cost_function_gradient;
  65.                         node_a_weight = node_a_weight + learning_rate * node_a_cost_function_gradient;
  66.  
  67.  
  68.  
  69.                         printf("\nout: %.15lf %.15lf %.15lf %.15lf %.15lf", node_a_input, node_b_input, node_a_output, node_b_output, node_b_desired_output);
  70.                         printf("\nout: %.15lf %.15lf", node_b_cost_function, node_b_cost_function_derivative);
  71.                       }
  72.  
  73.     scanf("%d",&i);
  74.     return 0;
  75. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement