hedgefund

zig_autograd_02a

Jan 16th, 2025
45
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
C 5.00 KB | Source Code | 0 0
  1. const std = @import("std");
  2. const testing = std.testing;
  3. const print = std.debug.print;
  4. const tMilli = std.time.milliTimestamp;
  5. const rand = std.rand;
  6.  
  7. const MAX_VARS = 1000;
  8. const MAX_OPS = 1000;
  9. const MAX_INPUTS = 2;
  10.  
  11. const OpType = enum(u2) {
  12.     Sum,
  13.     Prod,
  14.     Softplus,
  15. };
  16.  
  17. const NaiveVar = struct {
  18.     val: f64,
  19.     grad: f64,
  20.     id: usize,
  21.     is_active: bool,
  22. };
  23.  
  24. const Operation = struct {
  25.     op_type: OpType,
  26.     input_ids: [MAX_INPUTS]usize,
  27.     output_id: usize,
  28. };
  29.  
  30. const NaiveTape = struct {
  31.     vars: [MAX_VARS]NaiveVar,
  32.     var_count: usize,
  33.     ops: [MAX_OPS]Operation,
  34.     op_count: usize,
  35.  
  36.     fn init() NaiveTape {
  37.         return .{
  38.             .vars = undefined,
  39.             .var_count = 0,
  40.             .ops = undefined,
  41.             .op_count = 0,
  42.         };
  43.     }
  44.  
  45.     fn createVar(self: *NaiveTape, val: f64) usize {
  46.         const id = self.var_count;
  47.         self.vars[id] = NaiveVar {
  48.             .val = val,
  49.             .grad = 0.0,
  50.             .id = id,
  51.             .is_active = true,
  52.         };
  53.         self.var_count += 1;
  54.         return id;
  55.     }
  56.  
  57.     fn addOperation(self: *NaiveTape, op: Operation) void {
  58.         self.ops[self.op_count] = op;
  59.         self.op_count += 1;
  60.     }
  61.  
  62.     fn sum(self: *NaiveTape, input1_id: usize, input2_id: usize) usize {
  63.         const sum_val = self.vars[input1_id].val + self.vars[input2_id].val;
  64.         const result_id = self.createVar(sum_val);
  65.         const op = Operation {
  66.             .op_type = .Sum,
  67.             .input_ids = .{ input1_id, input2_id },
  68.             .output_id = result_id,
  69.         };
  70.         self.addOperation(op);
  71.         return result_id;
  72.     }
  73.  
  74.     fn prod(self: *NaiveTape, var1_id: usize, var2_id: usize) usize {
  75.         const prod_val = self.vars[var1_id].val * self.vars[var2_id].val;
  76.         const result_id = self.createVar(prod_val);
  77.         const op = Operation{
  78.             .op_type = .Prod,
  79.             .input_ids = .{ var1_id, var2_id },
  80.             .output_id = result_id,
  81.         };
  82.         self.addOperation(op);
  83.         return result_id;
  84.     }
  85.  
  86.     fn softplus(self: *NaiveTape, nvar_id: usize) usize {
  87.         const softplus_val = std.math.log1p(std.math.exp(self.vars[nvar_id].val));
  88.         const result_id = self.createVar(softplus_val);
  89.         const op = Operation{
  90.             .op_type = .Softplus,
  91.             .input_ids = .{ nvar_id, 0},
  92.             .output_id = result_id,
  93.         };
  94.         self.addOperation(op);
  95.         return result_id;
  96.     }
  97.  
  98.     fn backward(self: *NaiveTape, var_id: usize) void {
  99.         self.vars[var_id].grad = 1.0;
  100.         var i = self.op_count;
  101.         while (i > 0) {
  102.             i -= 1;
  103.             const op = self.ops[i];
  104.             const output_grad = self.vars[op.output_id].grad;
  105.  
  106.             switch (op.op_type) {
  107.                 .Sum => {
  108.                     self.vars[op.input_ids[0]].grad += output_grad;
  109.                     self.vars[op.input_ids[1]].grad += output_grad;
  110.                 },
  111.                 .Prod => {
  112.                     const input1_val = self.vars[op.input_ids[0]].val;
  113.                     const input2_val = self.vars[op.input_ids[1]].val;
  114.                     self.vars[op.input_ids[0]].grad += input2_val * output_grad;
  115.                     self.vars[op.input_ids[1]].grad += input1_val * output_grad;
  116.                 },
  117.                 .Softplus => {
  118.                     const input_val = self.vars[op.input_ids[0]].val;
  119.                     const exp_val = std.math.exp(-input_val);
  120.                     self.vars[op.input_ids[0]].grad += output_grad / (1.0 + exp_val);
  121.                 },
  122.             }
  123.         }
  124.     }
  125. };
  126.  
  127.  
  128. pub fn main() !void {
  129.     const iterations: usize = 1000000;
  130.     var rng = rand.DefaultPrng.init(@as(u64, @bitCast(tMilli())));
  131.  
  132.     const start_time = tMilli();
  133.     var i: usize = 0;
  134.  
  135.     while (i < iterations) : (i += 1) {
  136.         var tape = NaiveTape.init();
  137.  
  138.         // Reset the variables to 0
  139.         tape.var_count = 0;
  140.         tape.op_count = 0;
  141.  
  142.         const var1_id = tape.createVar(rng.random().float(f64));
  143.         const var2_id = tape.createVar(rng.random().float(f64));
  144.  
  145.         const sum_var_id = tape.sum(var1_id, var2_id);
  146.         const prod_var_id = tape.prod(sum_var_id, sum_var_id);
  147.         const softplus_var_id = tape.softplus(prod_var_id);
  148.  
  149.         tape.backward(softplus_var_id);
  150.  
  151.  
  152.         if (i == iterations - 1) {
  153.             print("sum_var val: {d}\n", .{tape.vars[sum_var_id].val});
  154.             print("prod_var val: {d}\n", .{tape.vars[prod_var_id].val});
  155.             print("softplus_var val: {d}\n", .{tape.vars[softplus_var_id].val});
  156.             print("sum_var grad: {d}\n", .{tape.vars[sum_var_id].grad});
  157.             print("var1 grad: {d}\n", .{tape.vars[var1_id].grad});
  158.             print("var2 grad: {d}\n", .{tape.vars[var2_id].grad});
  159.         }
  160.     }
  161.     const end_time = tMilli();
  162.     const elapsed_time = @as(f64, @floatFromInt(end_time - start_time));
  163.     print("\nElapsed time: {d:.3} ms\n", .{elapsed_time});
  164. }
  165.  
Advertisement
Add Comment
Please, Sign In to add comment