Untitled

a guest, Aug 25th, 2019
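from jax import jvp, grad

def f(x, y):
  return x + y**2

def freeze(f, argnum, val):
  # Fix position `argnum` (0 or 1) of the two-argument function `f` to `val`
  # and return a function of the remaining argument only.
  def _f(arg):
    args = [val, arg] if argnum == 0 else [arg, val]
    return f(*args)
  return _f

def mixed_jvp(f, order, primals, tangents):
  # Take grad w.r.t. argument order[0], hold that argument at its primal value,
  # then push `tangents` through argument order[1] with forward-mode jvp.
  # Returns (df/d[order[0]], tangent * d2f/d[order[1]]d[order[0]]).
  frozen_func = freeze(grad(f, order[0]), argnum=order[0], val=primals[order[0]])
  return jvp(frozen_func, (primals[order[1]],), tangents)

# At (x, y) = (2., 3.): df/dx = 1 and d2f/dydx = 0 for this f.
print(mixed_jvp(f, order=(0, 1), primals=(2., 3.), tangents=(1.,)))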
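For f(x, y) = x + y**2 the mixed partial is identically 0, so an argument-ordering mistake in `mixed_jvp` would go unnoticed. A minimal sanity check, sketched under the assumption that `freeze` holds the differentiated argument (position `order[0]`) at its primal value while `jvp` perturbs argument `order[1]`, is to compare against JAX's built-in `jax.jacfwd` and `jax.hessian` on a function with a nonzero mixed partial. The test function `g(x, y) = x * y**2` is hypothetical, chosen only because d2g/dydx = 2y is easy to verify by hand.

```python
import jax
from jax import grad, jvp

def freeze(f, argnum, val):
  # Fix position `argnum` (0 or 1) of a two-argument f to `val`.
  def _f(arg):
    args = [val, arg] if argnum == 0 else [arg, val]
    return f(*args)
  return _f

def mixed_jvp(f, order, primals, tangents):
  # grad w.r.t. argument order[0] (held at its primal value),
  # then a forward-mode JVP along argument order[1].
  frozen = freeze(grad(f, order[0]), argnum=order[0], val=primals[order[0]])
  return jvp(frozen, (primals[order[1]],), tangents)

def g(x, y):
  # Hypothetical test function: dg/dx = y**2, d2g/dydx = 2*y.
  return x * y**2

primal_out, mixed = mixed_jvp(g, order=(0, 1), primals=(2., 3.), tangents=(1.,))

# Reference values from JAX's own transformations.
ref_jacfwd = jax.jacfwd(grad(g, argnums=0), argnums=1)(2., 3.)
hess = jax.hessian(g, argnums=(0, 1))(2., 3.)  # nested tuple ((gxx, gxy), (gyx, gyy))

print(primal_out, mixed)        # expected values: y**2 = 9 and 2*y = 6
print(ref_jacfwd, hess[0][1])   # both should equal 6
```

Because both `jacfwd(grad(...))` and `hessian` compose reverse mode inside forward mode, they compute the same quantity as `mixed_jvp`, just for the full Jacobian rather than a single tangent direction.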