Advertisement
Not a member of Pastebin yet?
Sign Up,
it unlocks many cool features!
- from jax import jvp, grad
def f(x, y):
    """Return ``x + y**2``; a small two-argument test function."""
    y_squared = y ** 2
    return x + y_squared
def freeze(f, argnum, val):
    """Fix one argument of a two-argument function.

    Args:
        f: Callable taking exactly two positional arguments.
        argnum: Position that receives the fixed ``val``. ``0`` means the first
            argument; any other value means the second argument.
        val: Value to bind at position ``argnum``.

    Returns:
        A one-argument function ``g`` such that ``g(a) == f(val, a)`` when
        ``argnum == 0`` and ``g(a) == f(a, val)`` otherwise.
    """
    def _f(arg):
        # The position check happens on every call, exactly as before.
        if argnum == 0:
            return f(val, arg)
        return f(arg, val)
    return _f
def mixed_jvp(f, order, primals, tangents):
    """Evaluate a partial derivative of ``f`` and its derivative along another argument.

    Let ``i, j = order``. This returns the value of ``df/dx_i`` at ``primals``,
    together with the forward-mode derivative (JVP) of that partial with
    respect to argument ``j`` alone. Every other argument stays fixed at its
    value in ``primals``. When ``i != j`` the tangent output is the mixed
    second derivative ``d2f/(dx_j dx_i)`` times the tangent. When ``i == j``
    it is the pure second derivative times the tangent.

    Args:
        f: Scalar-valued function of ``len(primals)`` positional arguments
            (``jax.grad`` requires a scalar output).
        order: Pair ``(i, j)``: differentiate with respect to argument ``i``,
            then push a tangent through argument ``j``.
        primals: Point (one value per argument of ``f``) at which to evaluate.
        tangents: One-element tuple holding the tangent for argument ``j``.

    Returns:
        The ``(primal_out, tangent_out)`` pair produced by ``jax.jvp``.
    """
    i, j = order
    partial_i = grad(f, i)

    def partial_i_along_j(x_j):
        # Substitute only argument j into the full primal point, so argument
        # positions are never swapped. This also supports i == j and
        # functions with more than two arguments.
        args = list(primals)
        args[j] = x_j
        return partial_i(*args)

    return jvp(partial_i_along_j, (primals[j],), tangents)
if __name__ == "__main__":
    # Demo call, guarded so importing this module has no side effects; prints
    # the (primal, tangent) pair returned by mixed_jvp instead of discarding it.
    print(mixed_jvp(f, order=(0, 1), primals=(2., 3.), tangents=(1.,)))
Advertisement
Add Comment
Please, Sign In to add comment
Advertisement