Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- #include <stdio.h>
- #include <stdint.h>
- #include <math.h>
/* Scalar type used for every parameter, gradient and loss value. */
typedef double val_t;

/*
 * Small constant added to the denominator of the Adam step so it is never
 * exactly zero. Written as a double literal (no `f` suffix) so the value is
 * not rounded to float precision before being stored in a val_t.
 */
const val_t EPSILON = 1e-8;
- typedef struct param_t {
- val_t a; val_t b; val_t c;
- } param_t;
- typedef struct adam_t {
- val_t lr; val_t b1; val_t b2;
- uint64_t t;
- param_t m; param_t v;
- } adam_t;
- typedef struct arr_t {val_t *val; size_t len; } arr_t;
/*
 * Builds an arr_t from a literal list of values, e.g. ARRAY(0.0, 1.0, 2.0).
 * The values are stored in a compound literal, so the pointer inside the
 * result is only valid for the lifetime of that literal (the enclosing block
 * when the macro is used inside a function). The argument list is expanded
 * twice, but the second copy is an operand of sizeof and is not evaluated.
 */
#define ARRAY(...) ((arr_t){ \
    .val = (val_t[]){__VA_ARGS__}, \
    .len = sizeof((val_t[]){__VA_ARGS__}) / sizeof(val_t) \
})
- #define PARAM_DEFINE_OP(FN_NAME, OP) \
- static param_t FN_NAME(param_t X, param_t Y) { \
- param_t out; \
- out.a = X.a OP Y.a; \
- out.b = X.b OP Y.b; \
- out.c = X.c OP Y.c; \
- return out; \
- } \
- static param_t FN_NAME##_scalar(param_t X, val_t y) { \
- param_t out; \
- out.a = X.a OP y; \
- out.b = X.b OP y; \
- out.c = X.c OP y; \
- return out; \
- }
- PARAM_DEFINE_OP(param_add, +)
- PARAM_DEFINE_OP(param_sub, -)
- PARAM_DEFINE_OP(param_multiply, *)
- PARAM_DEFINE_OP(param_divide, /)
- param_t param_sqrt(param_t X) {
- param_t out;
- out.a = pow(X.a, 0.5);
- out.b = pow(X.b, 0.5);
- out.c = pow(X.c, 0.5);
- return out;
- }
- param_t param_interpolate(param_t X, param_t Y, val_t t) {
- param_t out;
- out.a = t * X.a + (1.0 - t) * Y.a;
- out.b = t * X.b + (1.0 - t) * Y.b;
- out.c = t * X.c + (1.0 - t) * Y.c;
- return out;
- }
- adam_t adam_create(val_t lr, val_t b1, val_t b2) {
- param_t m; param_t v;
- return (adam_t){lr, b1, b2, 0, m, v};
- }
- param_t adam_apply(adam_t* self, param_t g) {
- self->t++;
- self->m = param_interpolate(self->m, g, self->b1);
- self->v = param_interpolate(self->v, param_multiply(g, g), self->b2);
- param_t mhat = param_divide_scalar(self->m, (1-pow(self->b1, self->t)));
- param_t vhat = param_divide_scalar(self->v, (1-pow(self->b2, self->t)));
- mhat = param_multiply_scalar(mhat, self->lr);
- vhat = param_add_scalar(param_sqrt(vhat), EPSILON);
- return param_divide(mhat,vhat);
- }
- param_t forwardback(arr_t X, arr_t Y, param_t theta, val_t* loss) {
- val_t err = 0;
- param_t grad = {0, 0, 0};
- for (int ix = 0; ix < X.len; ix++) {
- val_t x = X.val[ix];
- val_t x2 = x*x;
- val_t dLdyhat = (theta.a*x2 + theta.b*x + theta.c) - Y.val[ix];
- err += dLdyhat*dLdyhat;
- grad = param_add(grad, (param_t){dLdyhat*x2, dLdyhat*x, dLdyhat});
- }
- *loss = err / (X.len << 1);
- return param_divide_scalar(grad, X.len);
- }
- int main() {
- val_t loss; param_t grad;
- param_t theta = {1.2, 3.0, -6.0};
- adam_t adam = adam_create(0.1, 0.9, 0.999);
- arr_t X = ARRAY(0.0, 1.0, 2.0, 3.0);
- arr_t Y = ARRAY(0.9, 0.1, 0.9, 4.1);
- for (int i = 0; i < 1000; i++) {
- grad = forwardback(X, Y, theta, &loss);
- grad = adam_apply(&adam, grad);
- theta = param_sub(theta, grad);
- if (i % 100 == 0) {
- printf("Round: %i, loss: %lf\n", i, loss);
- }
- }
- printf("a: %lf, b: %lf, c: %lf\n", theta.a, theta.b, theta.c);
- }
- /**
- * Round: 0, loss: 20.065000
- * ...
- * Round: 900, loss: 0.004000
- * a: 0.999992, b: -1.959973, c: 0.939987
- **/
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement