Advertisement
Guest User

Untitled

a guest
Aug 25th, 2019
86
0
Never
Not a member of Pastebin yet? Sign Up, it unlocks many cool features!
text 0.42 KB | None | 0 0
  1. from jax import jvp, grad
  2.  
  3. def f(x,y):
  4. return x + y**2
  5.  
  6. def freeze(f, argnum, val):
  7. def _f(arg):
  8. args = [val, arg] if argnum == 0 else [arg, val]
  9. return f(*args)
  10. return _f
  11.  
  12. def mixed_jvp(f, order, primals, tangents):
  13. frozen_func = freeze(grad(f, order[0]), argnum=order[1], val=primals[order[0]])
  14. return jvp(frozen_func, (primals[order[1]],), tangents)
  15.  
  16. mixed_jvp(f, order=(0,1), primals=(2., 3.), tangents=(1.,))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement