Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- const std = @import("std");
- const testing = std.testing;
- const print = std.debug.print;
- const tMilli = std.time.milliTimestamp;
- const rand = std.rand;
/// Capacity of the tape's variable array (`NaiveTape.vars`).
const MAX_VARS = 1_000;
/// Capacity of the tape's operation array (`NaiveTape.ops`).
const MAX_OPS = 1_000;
/// Number of input slots stored per recorded operation.
const MAX_INPUTS = 2;
/// Kinds of operation that can be recorded on the tape.
/// Backed by a `u2`, which is wide enough for the three tags.
const OpType = enum(u2) {
    /// Binary addition of two variables.
    Sum,
    /// Binary multiplication of two variables.
    Prod,
    /// Unary softplus, `ln(1 + e^x)`.
    Softplus,
};
/// One differentiable scalar stored on the tape.
const NaiveVar = struct {
    /// Value computed in the forward pass.
    val: f64,
    /// Gradient accumulated by the backward pass; starts at 0.0 when the
    /// variable is created through `NaiveTape.createVar`.
    grad: f64,
    /// Index of this variable inside the tape's `vars` array.
    id: usize,
    /// Set to `true` on creation; not read anywhere in the visible code.
    is_active: bool,
};
/// A single recorded step of the forward computation.
const Operation = struct {
    /// Which operation produced the output.
    op_type: OpType,
    /// Tape indices of the operands. Unary operations only use slot 0;
    /// the remaining slot is filled with 0 and ignored.
    input_ids: [MAX_INPUTS]usize,
    /// Tape index of the variable holding the result.
    output_id: usize,
};
/// Fixed-capacity tape for reverse-mode automatic differentiation.
///
/// Variables and operations are appended in creation order and addressed by
/// index. `backward` replays the recorded operations in reverse order to
/// accumulate gradients. All storage is inline (no allocator): at most
/// `MAX_VARS` variables and `MAX_OPS` operations.
const NaiveTape = struct {
    /// Variable slots; only `vars[0..var_count]` hold initialized data.
    vars: [MAX_VARS]NaiveVar,
    /// Number of variables created so far.
    var_count: usize,
    /// Recorded operations; only `ops[0..op_count]` hold initialized data.
    ops: [MAX_OPS]Operation,
    /// Number of operations recorded so far.
    op_count: usize,

    /// Returns an empty tape. The backing arrays are deliberately left
    /// `undefined`: a slot is always written before it becomes reachable
    /// through `var_count` / `op_count`.
    fn init() NaiveTape {
        return .{
            .vars = undefined,
            .var_count = 0,
            .ops = undefined,
            .op_count = 0,
        };
    }

    /// Appends a new variable holding `val` with a zero gradient and returns
    /// its tape index. Panics if the tape already holds `MAX_VARS` variables.
    fn createVar(self: *NaiveTape, val: f64) usize {
        const id = self.var_count;
        // Explicit guard: a bare out-of-bounds index is only checked in safe
        // build modes and would be undefined behaviour in ReleaseFast.
        if (id >= MAX_VARS) @panic("NaiveTape: variable capacity (MAX_VARS) exceeded");
        self.vars[id] = .{
            .val = val,
            .grad = 0.0,
            .id = id,
            .is_active = true,
        };
        self.var_count = id + 1;
        return id;
    }

    /// Appends `op` to the operation list.
    /// Panics if the tape already holds `MAX_OPS` operations.
    fn addOperation(self: *NaiveTape, op: Operation) void {
        const slot = self.op_count;
        if (slot >= MAX_OPS) @panic("NaiveTape: operation capacity (MAX_OPS) exceeded");
        self.ops[slot] = op;
        self.op_count = slot + 1;
    }

    /// Creates the result variable for an operation, records the operation,
    /// and returns the result variable's tape index.
    fn record(self: *NaiveTape, op_type: OpType, input_ids: [MAX_INPUTS]usize, val: f64) usize {
        const result_id = self.createVar(val);
        self.addOperation(.{
            .op_type = op_type,
            .input_ids = input_ids,
            .output_id = result_id,
        });
        return result_id;
    }

    /// Records `vars[input1_id] + vars[input2_id]` and returns the index of
    /// the result variable.
    fn sum(self: *NaiveTape, input1_id: usize, input2_id: usize) usize {
        const sum_val = self.vars[input1_id].val + self.vars[input2_id].val;
        return self.record(.Sum, .{ input1_id, input2_id }, sum_val);
    }

    /// Records `vars[var1_id] * vars[var2_id]` and returns the index of the
    /// result variable.
    fn prod(self: *NaiveTape, var1_id: usize, var2_id: usize) usize {
        const prod_val = self.vars[var1_id].val * self.vars[var2_id].val;
        return self.record(.Prod, .{ var1_id, var2_id }, prod_val);
    }

    /// Records `softplus(vars[nvar_id]) = ln(1 + e^x)` and returns the index
    /// of the result variable.
    fn softplus(self: *NaiveTape, nvar_id: usize) usize {
        const x = self.vars[nvar_id].val;
        // ln(1 + e^x) == x + ln(1 + e^-x). Using the second form for
        // positive x keeps the exponent non-positive, so exp() cannot
        // overflow to +inf for large inputs.
        const softplus_val = if (x > 0.0)
            x + std.math.log1p(std.math.exp(-x))
        else
            std.math.log1p(std.math.exp(x));
        // Unary op: the second input slot is unused and filled with 0.
        return self.record(.Softplus, .{ nvar_id, 0 }, softplus_val);
    }

    /// Runs the backward pass, seeding `vars[var_id].grad` with 1.0 and
    /// walking the recorded operations from last to first.
    ///
    /// Gradients of the other variables are accumulated with `+=` and are
    /// not cleared first, so calling this more than once on the same tape
    /// adds to the gradients left by the previous call.
    fn backward(self: *NaiveTape, var_id: usize) void {
        self.vars[var_id].grad = 1.0;

        var remaining = self.op_count;
        while (remaining > 0) {
            remaining -= 1;
            const op = self.ops[remaining];
            const output_grad = self.vars[op.output_id].grad;
            const lhs_id = op.input_ids[0];
            const rhs_id = op.input_ids[1];

            switch (op.op_type) {
                // d(a + b)/da = d(a + b)/db = 1
                .Sum => {
                    self.vars[lhs_id].grad += output_grad;
                    self.vars[rhs_id].grad += output_grad;
                },
                // d(a * b)/da = b, d(a * b)/db = a. Both values are read
                // before either gradient is updated; when both inputs are
                // the same variable it correctly receives 2 * a * grad.
                .Prod => {
                    const lhs_val = self.vars[lhs_id].val;
                    const rhs_val = self.vars[rhs_id].val;
                    self.vars[lhs_id].grad += rhs_val * output_grad;
                    self.vars[rhs_id].grad += lhs_val * output_grad;
                },
                // d softplus(x)/dx = sigmoid(x) = 1 / (1 + e^-x)
                .Softplus => {
                    const input_val = self.vars[lhs_id].val;
                    const exp_neg = std.math.exp(-input_val);
                    self.vars[lhs_id].grad += output_grad / (1.0 + exp_neg);
                },
            }
        }
    }
};
/// Benchmark driver: builds and differentiates `softplus((a + b) * (a + b))`
/// for random `a`, `b` one million times, prints the values and gradients of
/// the final iteration, and reports the elapsed time in milliseconds.
///
/// Returns an error only if no monotonic clock is available for the timer.
pub fn main() !void {
    const iterations: usize = 1000000;
    var rng = rand.DefaultPrng.init(@as(u64, @bitCast(tMilli())));
    const random = rng.random();

    // One tape is reused for every iteration. Resetting the two counters is
    // enough because `createVar` rewrites each slot (including a zero
    // gradient) before it is used again.
    var tape = NaiveTape.init();

    // Monotonic timer: unaffected by wall-clock adjustments and provides
    // sub-millisecond resolution for the report below.
    var timer = try std.time.Timer.start();

    for (0..iterations) |i| {
        // Reset the tape to empty.
        tape.var_count = 0;
        tape.op_count = 0;

        const var1_id = tape.createVar(random.float(f64));
        const var2_id = tape.createVar(random.float(f64));
        const sum_var_id = tape.sum(var1_id, var2_id);
        const prod_var_id = tape.prod(sum_var_id, sum_var_id);
        const softplus_var_id = tape.softplus(prod_var_id);
        tape.backward(softplus_var_id);

        // Only report the results of the last iteration.
        if (i == iterations - 1) {
            print("sum_var val: {d}\n", .{tape.vars[sum_var_id].val});
            print("prod_var val: {d}\n", .{tape.vars[prod_var_id].val});
            print("softplus_var val: {d}\n", .{tape.vars[softplus_var_id].val});
            print("sum_var grad: {d}\n", .{tape.vars[sum_var_id].grad});
            print("var1 grad: {d}\n", .{tape.vars[var1_id].grad});
            print("var2 grad: {d}\n", .{tape.vars[var2_id].grad});
        }
    }

    // `Timer.read` returns nanoseconds; convert to milliseconds.
    const elapsed_time = @as(f64, @floatFromInt(timer.read())) / 1_000_000.0;
    print("\nElapsed time: {d:.3} ms\n", .{elapsed_time});
}
Advertisement
Add Comment
Please, Sign In to add comment