hedgefund

codon_autograd_v02a

Jan 16th, 2025
41
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
Python 3.63 KB | Source Code | 0 0
import math
import random
from time import perf_counter, time
from typing import List, Optional, Tuple
  5.  
  6. MAX_VARS = 1000
  7. MAX_OPS = 1000
  8. MAX_INPUTS = 2
  9.  
  10.  
  11. class OpType:
  12.     OP_SUM = 0
  13.     OP_PROD = 1
  14.     OP_SOFTPLUS = 2
  15.  
  16.  
  17. class NaiveVar:
  18.     val: float
  19.     grad: float
  20.  
  21.     def __init__(self, val: float) -> None:
  22.         self.val = val
  23.         self.grad = 0.0
  24.  
  25.  
  26. class Operation:
  27.     op_type: int
  28.     input_ids: Tuple[int, int]
  29.     output_id: int
  30.  
  31.     def __init__(self, op_type: int, input_ids: Tuple[int, int], output_id: int) -> None:
  32.         self.op_type = op_type
  33.         self.input_ids = input_ids
  34.         self.output_id = output_id
  35.  
  36.  
  37. class NaiveTape:
  38.     vars: List[NaiveVar]
  39.     ops: List[Operation]
  40.  
  41.     def __init__(self):
  42.         self.vars = []
  43.         self.ops = []
  44.  
  45.  
  46.     def create_var(self, val: float) -> int:
  47.         var = NaiveVar(val)
  48.         id = len(self.vars)
  49.         self.vars.append(var)
  50.         return id
  51.  
  52.     def add_operation(self, op: Operation) -> None:
  53.         self.ops.append(op)
  54.  
  55.     def backward(self, var_id: int) -> None:
  56.         self.vars[var_id].grad = 1.0
  57.         i = len(self.ops)
  58.  
  59.         while i > 0:
  60.             i -= 1
  61.             op = self.ops[i]
  62.             output_grad = self.vars[op.output_id].grad
  63.  
  64.             if op.op_type == OpType.OP_SUM:
  65.                 self.vars[op.input_ids[0]].grad += output_grad
  66.                 self.vars[op.input_ids[1]].grad += output_grad
  67.  
  68.             elif op.op_type == OpType.OP_PROD:
  69.                 input1_val = self.vars[op.input_ids[0]].val
  70.                 input2_val = self.vars[op.input_ids[1]].val
  71.                 self.vars[op.input_ids[0]].grad += input2_val * output_grad
  72.                 self.vars[op.input_ids[1]].grad += input1_val * output_grad
  73.  
  74.             elif op.op_type == OpType.OP_SOFTPLUS:
  75.                 input_val = self.vars[op.input_ids[0]].val
  76.                 exp_val = math.exp(-input_val)
  77.                 self.vars[op.input_ids[0]].grad += output_grad / (1.0 + exp_val)
  78.  
  79.  
  80.  
  81. def main():
  82.     iterations: int = 1000000
  83.     # random.seed(time.time())
  84.     start_time = time()
  85.  
  86.     for i in range(iterations):
  87.         tape = NaiveTape()
  88.         var1_id = tape.create_var(random.random())
  89.         var2_id = tape.create_var(random.random())
  90.  
  91.         # Inlined sum
  92.         sum_val = tape.vars[var1_id].val + tape.vars[var2_id].val
  93.         sum_var_id = tape.create_var(sum_val)
  94.         op_sum = Operation(OpType.OP_SUM, (var1_id, var2_id), sum_var_id)
  95.         tape.add_operation(op_sum)
  96.  
  97.         # Inlined prod
  98.         prod_val = tape.vars[sum_var_id].val * tape.vars[sum_var_id].val
  99.         prod_var_id = tape.create_var(prod_val)
  100.         op_prod = Operation(OpType.OP_PROD, (sum_var_id, sum_var_id), prod_var_id)
  101.         tape.add_operation(op_prod)
  102.  
  103.         # Inlined softplus
  104.         softplus_val = math.log1p(math.exp(tape.vars[prod_var_id].val))
  105.         softplus_var_id = tape.create_var(softplus_val)
  106.         op_softplus = Operation(OpType.OP_SOFTPLUS, (prod_var_id, 0), softplus_var_id)
  107.         tape.add_operation(op_softplus)
  108.  
  109.         tape.backward(softplus_var_id)
  110.  
  111.         if i == iterations - 1:
  112.             print(f"sum_var val: {tape.vars[sum_var_id].val}")
  113.             print(f"prod_var val: {tape.vars[prod_var_id].val}")
  114.             print(f"softplus_var val: {tape.vars[softplus_var_id].val}")
  115.             print(f"sum_var grad: {tape.vars[sum_var_id].grad}")
  116.             print(f"var1 grad: {tape.vars[var1_id].grad}")
  117.             print(f"var2 grad: {tape.vars[var2_id].grad}")
  118.  
  119.     end_time = time()
  120.     elapsed_time = (end_time - start_time) * 1000
  121.     print(f"\nElapsed time: {elapsed_time:.3f} ms")
  122.  
  123.  
  124. main()
  125.  
Advertisement
Add Comment
Please, Sign In to add comment