hedgefund

c_autograd_02

Jan 16th, 2025
42
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C 5.12 KB | Source Code | 0 0
  1. #include <stddef.h>
  2. #include <stdio.h>
  3. #include <stdlib.h>
  4. #include <math.h>
  5. #include <string.h>
  6. #include <time.h>
  7.  
  8. #define MAX_VARS 1000
  9. #define MAX_OPS 1000
  10. #define MAX_INPUTS 2
  11. #define MAX_OUTPUTS 1
  12.  
  13. typedef struct {
  14.     double val;
  15.     double grad;
  16.     int id;
  17.     int is_active;
  18. } NaiveVar;
  19.  
  20. typedef enum {
  21.     OP_SUM,
  22.     OP_PROD,
  23.     OP_SOFTPLUS
  24. } OpType;
  25.  
  26. typedef struct {
  27.     OpType type;
  28.     int n_inputs;
  29.     int input_ids[MAX_INPUTS];
  30.     int output_ids[MAX_OUTPUTS];
  31. } MyOperation;
  32.  
  33. typedef struct {
  34.     NaiveVar vars[MAX_VARS];
  35.     int var_count;
  36.     MyOperation ops[MAX_OPS];
  37.     int op_count;
  38. } NaiveTape;
  39.  
  40. NaiveTape create_tape() {
  41.     NaiveTape tape = {
  42.         .var_count = 0,
  43.         .op_count = 0
  44.     };
  45.     return tape;
  46. }
  47.  
  48. int create_var(NaiveTape* tape, double val)
  49. {
  50.     if (tape->var_count >= MAX_VARS) {
  51.         fprintf(stderr, "Maximum variable limit reached\n");
  52.         exit(1);
  53.     }
  54.    
  55.     int id = tape->var_count++;
  56.     tape->vars[id] = (NaiveVar){
  57.         .val = val,
  58.         .grad = 0.0,
  59.         .id = id,
  60.         .is_active = 1
  61.     };
  62.     return id;
  63. }
  64.  
  65. void add_operation(NaiveTape* tape, MyOperation op)
  66. {
  67.     if (tape->op_count >= MAX_OPS) {
  68.         fprintf(stderr, "Maximum operation limit reached\n");
  69.         exit(1);
  70.     }
  71.     tape->ops[tape->op_count++] = op;
  72. }
  73.  
  74. int tape_sum(NaiveTape* tape, int* var_ids, int n_vars)
  75. {
  76.     double sum_val = 0.0;
  77.     for (int i = 0; i < n_vars; i++) {
  78.         sum_val += tape->vars[var_ids[i]].val;
  79.     }
  80.    
  81.     int result_id = create_var(tape, sum_val);
  82.    
  83.     MyOperation op = {
  84.         .type = OP_SUM,
  85.         .n_inputs = n_vars
  86.     };
  87.    
  88.     memcpy(op.input_ids, var_ids, n_vars * sizeof(int));
  89.     op.output_ids[0] = result_id;
  90.    
  91.     add_operation(tape, op);
  92.     return result_id;
  93. }
  94.  
  95. int tape_prod(NaiveTape* tape, int var1_id, int var2_id)
  96. {
  97.     double prod_val = tape->vars[var1_id].val * tape->vars[var2_id].val;
  98.     int result_id = create_var(tape, prod_val);
  99.    
  100.     MyOperation op = {
  101.         .type = OP_PROD,
  102.         .n_inputs = 2,
  103.         .input_ids = {var1_id, var2_id},
  104.         .output_ids = {result_id}
  105.     };
  106.    
  107.     add_operation(tape, op);
  108.     return result_id;
  109. }
  110.  
  111. int tape_softplus(NaiveTape* tape, int var_id)
  112. {
  113.     double softplus_val = log1p(exp(tape->vars[var_id].val));
  114.     int result_id = create_var(tape, softplus_val);
  115.    
  116.     MyOperation op = {
  117.         .type = OP_SOFTPLUS,
  118.         .n_inputs = 1,
  119.         .input_ids = {var_id},
  120.         .output_ids = {result_id}
  121.     };
  122.    
  123.     add_operation(tape, op);
  124.     return result_id;
  125. }
  126.  
  127. void tape_backward(NaiveTape* tape, int var_id)
  128. {
  129.     tape->vars[var_id].grad = 1.0;
  130.    
  131.     for (int i = tape->op_count - 1; i >= 0; i--) {
  132.         MyOperation* op = &tape->ops[i];
  133.        
  134.         switch (op->type) {
  135.             case OP_SUM: {
  136.                 double grad = tape->vars[op->output_ids[0]].grad;
  137.                 for (int j = 0; j < op->n_inputs; j++) {
  138.                     tape->vars[op->input_ids[j]].grad += grad;
  139.                 }
  140.                 break;
  141.             }
  142.             case OP_PROD: {
  143.                 double grad = tape->vars[op->output_ids[0]].grad;
  144.                 NaiveVar* in1 = &tape->vars[op->input_ids[0]];
  145.                 NaiveVar* in2 = &tape->vars[op->input_ids[1]];
  146.                 in1->grad += in2->val * grad;
  147.                 in2->grad += in1->val * grad;
  148.                 break;
  149.             }
  150.             case OP_SOFTPLUS: {
  151.                 int input_id = op->input_ids[0];
  152.                 double exp_val = exp(-tape->vars[input_id].val);
  153.                 tape->vars[input_id].grad +=
  154.                     tape->vars[op->output_ids[0]].grad /
  155.                     (1.0 + exp_val);
  156.                 break;
  157.             }
  158.         }
  159.     }
  160. }
  161.  
  162. int main()
  163. {
  164.     clock_t start, end;
  165.     double elapsed_time;
  166.     size_t iterations = 1000000;
  167.     srand(time(NULL));
  168.  
  169.     start = clock();
  170.     for (size_t iter = 0; iter < iterations; iter++) {
  171.         NaiveTape tape = create_tape();
  172.        
  173.         int x_id = create_var(&tape, (float)rand() / RAND_MAX);
  174.         int y_id = create_var(&tape, (float)rand() / RAND_MAX);
  175.        
  176.         int var_ids[] = {x_id, y_id};
  177.         int sum_xy_id = tape_sum(&tape, var_ids, 2);
  178.         int prod_sum_xy_id = tape_prod(&tape, sum_xy_id, sum_xy_id);
  179.         int softplus_prod_id = tape_softplus(&tape, prod_sum_xy_id);
  180.        
  181.         tape_backward(&tape, softplus_prod_id);
  182.        
  183.         if (iter == iterations - 1) {
  184.             printf("sum_xy->val: %f\n", tape.vars[sum_xy_id].val);
  185.             printf("prod_sum_xy->val: %f\n", tape.vars[prod_sum_xy_id].val);
  186.             printf("softplus_prod->val: %f\n", tape.vars[softplus_prod_id].val);
  187.             printf("sum_xy->grad: %f\n", tape.vars[sum_xy_id].grad);
  188.             printf("x->grad: %f\n", tape.vars[x_id].grad);
  189.             printf("y->grad: %f\n", tape.vars[y_id].grad);
  190.         }
  191.     }
  192.    
  193.     end = clock();
  194.     elapsed_time = ((double) (end - start)) * 1000 / CLOCKS_PER_SEC;
  195.     printf("\nElapsed time: %f ms\n", elapsed_time);
  196.  
  197.     return 0;
  198. }
  199.  
Advertisement
Add Comment
Please, Sign In to add comment