Advertisement
yohoburner

C Gradient Descent

May 12th, 2024
313
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C 3.17 KB | None | 0 0
  1. #include <stdio.h>
  2. #include <stdint.h>
  3. #include <math.h>
  4.  
  5. typedef double val_t;
  6.  
  7. const val_t EPSILON = 1e-8f;
  8.  
  9. typedef struct param_t {
  10.     val_t a; val_t b; val_t c;
  11. } param_t;
  12.  
  13. typedef struct adam_t {
  14.     val_t lr; val_t b1; val_t b2;
  15.     uint64_t t;
  16.     param_t m; param_t v;
  17. } adam_t;
  18.  
  19. typedef struct arr_t {val_t *val; size_t len; } arr_t;
  20.  
  21. #define ARRAY(...) ((arr_t){ \
  22.     (val_t[]){__VA_ARGS__}, \
  23.     sizeof((val_t[]){__VA_ARGS__}) / sizeof(val_t) \
  24. })
  25.  
  26. #define PARAM_DEFINE_OP(FN_NAME, OP) \
  27.     static param_t FN_NAME(param_t X, param_t Y) { \
  28.         param_t out; \
  29.         out.a = X.a OP Y.a; \
  30.         out.b = X.b OP Y.b; \
  31.         out.c = X.c OP Y.c; \
  32.         return out; \
  33.     } \
  34.     static param_t FN_NAME##_scalar(param_t X, val_t y) { \
  35.         param_t out; \
  36.         out.a = X.a OP y; \
  37.         out.b = X.b OP y; \
  38.         out.c = X.c OP y; \
  39.         return out; \
  40.     }
  41.  
  42. PARAM_DEFINE_OP(param_add, +)
  43. PARAM_DEFINE_OP(param_sub, -)
  44. PARAM_DEFINE_OP(param_multiply, *)
  45. PARAM_DEFINE_OP(param_divide, /)
  46.  
  47. param_t param_sqrt(param_t X) {
  48.     param_t out;
  49.     out.a = pow(X.a, 0.5);
  50.     out.b = pow(X.b, 0.5);
  51.     out.c = pow(X.c, 0.5);
  52.     return out;
  53. }
  54.  
  55. param_t param_interpolate(param_t X, param_t Y, val_t t) {
  56.     param_t out;
  57.     out.a = t * X.a + (1.0 - t) * Y.a;
  58.     out.b = t * X.b + (1.0 - t) * Y.b;
  59.     out.c = t * X.c + (1.0 - t) * Y.c;
  60.     return out;
  61. }
  62.  
  63. adam_t adam_create(val_t lr, val_t b1, val_t b2) {
  64.     param_t m; param_t v;
  65.     return (adam_t){lr, b1, b2, 0, m, v};
  66. }
  67.  
  68. param_t adam_apply(adam_t* self, param_t g) {
  69.     self->t++;
  70.     self->m = param_interpolate(self->m, g, self->b1);
  71.     self->v = param_interpolate(self->v, param_multiply(g, g), self->b2);
  72.     param_t mhat = param_divide_scalar(self->m, (1-pow(self->b1, self->t)));
  73.     param_t vhat = param_divide_scalar(self->v, (1-pow(self->b2, self->t)));
  74.     mhat = param_multiply_scalar(mhat, self->lr);
  75.     vhat = param_add_scalar(param_sqrt(vhat), EPSILON);
  76.     return param_divide(mhat,vhat);
  77. }
  78.  
  79. param_t forwardback(arr_t X, arr_t Y, param_t theta, val_t* loss) {
  80.     val_t err = 0;
  81.     param_t grad = {0, 0, 0};
  82.     for (int ix = 0; ix < X.len; ix++) {
  83.         val_t x = X.val[ix];
  84.         val_t x2 = x*x;
  85.         val_t dLdyhat = (theta.a*x2 + theta.b*x + theta.c) - Y.val[ix];
  86.         err += dLdyhat*dLdyhat;
  87.         grad = param_add(grad, (param_t){dLdyhat*x2, dLdyhat*x, dLdyhat});
  88.     }
  89.     *loss = err / (X.len << 1);
  90.     return param_divide_scalar(grad, X.len);
  91. }
  92.  
  93. int main() {
  94.     val_t loss; param_t grad;
  95.     param_t theta = {1.2, 3.0, -6.0};
  96.     adam_t adam = adam_create(0.1, 0.9, 0.999);
  97.  
  98.     arr_t X = ARRAY(0.0, 1.0, 2.0, 3.0);
  99.     arr_t Y = ARRAY(0.9, 0.1, 0.9, 4.1);
  100.  
  101.     for (int i = 0; i < 1000; i++) {
  102.         grad = forwardback(X, Y, theta, &loss);
  103.         grad = adam_apply(&adam, grad);
  104.         theta = param_sub(theta, grad);
  105.         if (i % 100 == 0) {
  106.             printf("Round: %i, loss: %lf\n", i, loss);
  107.         }
  108.     }
  109.     printf("a: %lf, b: %lf, c: %lf\n", theta.a, theta.b, theta.c);
  110. }
  111. /**
  112.  * Round: 0, loss: 20.065000
  113.  * ...
  114.  * Round: 900, loss: 0.004000
  115.  * a: 0.999992, b: -1.959973, c: 0.939987
  116. **/
  117.  
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement