Advertisement
Guest User

nn.c

a guest
Feb 18th, 2018
66
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C 2.10 KB | None | 0 0
  1. #include "nn.h"
  2.  
  3. static float sigmoid(float a) {
  4.     return (0.5*a) / (1 + fabs(a)) + 0.5;
  5. }
  6.  
  7.  
  8. void nn_process(nn *net) {
  9.     float hidden1[NN_NUMHIDDEN];
  10.     float hidden2[NN_NUMHIDDEN];
  11.     int w = 0;
  12.     int i, h, o;
  13.  
  14.     /* hidden layer 1 */
  15.     for (h = 0; h < NN_NUMHIDDEN; h++) {
  16.         float activation = 0;
  17.         for (i = 0; i < NN_NUMINPUTS; i++)
  18.             activation += net->weight[w++] * net->input[i];
  19.  
  20.         activation += net->weight[w++];
  21.         hidden1[h] = sigmoid(activation);
  22.     }
  23.  
  24.     /* hidden layer 2 */
  25.     for (h = 0; h < NN_NUMHIDDEN; h++) {
  26.         float activation = 0;
  27.         for (i = 0; i < NN_NUMHIDDEN; i++)
  28.             activation += net->weight[w++] * hidden1[i];
  29.  
  30.         activation += net->weight[w++];
  31.         hidden2[h] = sigmoid(activation);
  32.     }
  33.  
  34.     /* output layer */
  35.     for (o = 0; o < NN_NUMOUTPUTS; o++) {
  36.         float activation = 0;
  37.         for (h = 0; h < NN_NUMHIDDEN; h++)
  38.             activation += net->weight[w++] * hidden2[h];
  39.  
  40.         activation += net->weight[w++];
  41.         net->output[o] = sigmoid(activation);
  42.     }
  43. }
  44.  
  45.  
  46. void nn_fillrandom(nn *net) {
  47.     int w;
  48.     for (w = 0; w < NN_NUMWEIGHTS; w++)
  49.         net->weight[w] = RANDDBL(-8.0, 8.0);
  50. }
  51.  
  52. void nn_mutate(nn *net) {
  53.     if (DICE(10)) {
  54.         /* high mutation */
  55.         int w;
  56.         for (w = 0; w < NN_NUMWEIGHTS; w++)
  57.             if (DICE(200)) net->weight[w] += RANDDBL(-8, 8);
  58.     }
  59.  
  60.     else {
  61.         int w;
  62.         for (w = 0; w < NN_NUMWEIGHTS; w++) {
  63.             if (DICE(500)) net->weight[w] += RANDDBL(-1, 1);
  64.             if (DICE(2000)) net->weight[w] += RANDDBL(-8, 8);
  65.         }
  66.     }
  67.  
  68.     /* clamp */
  69.     int w;
  70.     for (w = 0; w < NN_NUMWEIGHTS; w++)
  71.         net->weight[w] = RETMAX(-10, RETMIN(10, net->weight[w]));
  72. }
  73.  
  74. void nn_copy(nn *dest, nn *src) {
  75.     int w;
  76.     for (w = 0; w < NN_NUMWEIGHTS; w++)
  77.         dest->weight[w] = src->weight[w];
  78. }
  79.  
  80. void nn_print(nn *net) {
  81.     int i, h, o;
  82.     int w = 0;
  83.     for (h = 0; h < NN_NUMHIDDEN; h++) {
  84.         for (i = 0; i < NN_NUMINPUTS; i++)
  85.             printf("%2.2f ", net->weight[w++]);
  86.            
  87.         printf("%2.2f\n", net->weight[w++]);
  88.     }
  89.  
  90.     printf("\n\n");
  91.  
  92.     for (o = 0; o < NN_NUMOUTPUTS; o++) {
  93.         for (h = 0; h < NN_NUMHIDDEN; h++)
  94.             printf("%2.2f ", net->weight[w++]);
  95.            
  96.         printf("%2.2f\n", net->weight[w++]);
  97.     }
  98.  
  99.     printf("\n\n");
  100. }
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement