https://github.com/kach/jax.value_and_jacfwd
Provides an implementation of a missing primitive in JAX, value_and_jacfwd
- Host: GitHub
- URL: https://github.com/kach/jax.value_and_jacfwd
- Owner: kach
- License: mit
- Created: 2021-12-27T16:01:38.000Z (over 4 years ago)
- Default Branch: main
- Last Pushed: 2021-12-27T16:11:11.000Z (over 4 years ago)
- Last Synced: 2025-01-28T20:41:29.050Z (over 1 year ago)
- Topics: automatic-differentiation, jax
- Language: Python
- Homepage:
- Size: 3.91 KB
- Stars: 0
- Watchers: 2
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.txt
- License: LICENSE
README
value_and_jacfwd.py
Copyright (c) 2021 Kartik Chandra; see MIT license attached
This patch adds a function jax.value_and_jacfwd, the forward-mode
counterpart of jax.value_and_grad. It returns the value of a
function together with its derivative (Jacobian), so you don't
need two separate evaluations (one call to the function for its
value and one to plain jax.jacfwd for its derivative) to get both.
For example:
>>> import jax, value_and_jacfwd
>>> import jax.numpy as np
>>> def g(x):
...     return (x ** 2).sum()
>>> vg = jax.value_and_jacfwd(g, has_aux=False)
>>> y, dg = vg(np.arange(3) * 1.)
>>> print(f'g(x) = {y}')
g(x) = 5.0
>>> print(f'dg(x) = {dg}')
dg(x) = [0. 2. 4.]
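Why forward mode can return the value at no extra cost: JAX's public jax.jvp already computes the primal output alongside the derivative along one tangent direction. The sketch below uses only jax.jvp (it is not taken from value_and_jacfwd.py); the tangent vector t is an illustrative choice.

```python
import jax
import jax.numpy as jnp


def g(x):
    return (x ** 2).sum()


x = jnp.arange(3) * 1.0          # [0., 1., 2.], as in the example above
t = jnp.array([0.0, 1.0, 0.0])   # illustrative tangent: 2nd standard basis vector

# One forward-mode pass yields both the primal value g(x) and the
# directional derivative grad(g)(x) . t.
y, y_dot = jax.jvp(g, (x,), (t,))
print(y, y_dot)  # g(x) = 0 + 1 + 4 = 5.0; directional derivative 2 * x[1] = 2.0
```

Repeating this pass for every standard basis vector (which jax.jacfwd does internally with jax.vmap) assembles the full Jacobian; keeping the primal output from that same pass is what makes a separate call to g unnecessary.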
You can also return auxiliary values with the has_aux=True
parameter, again by analogy to jax.value_and_grad: the function
then returns a pair (value, aux), and only value is differentiated.
For example:
>>> import jax, value_and_jacfwd
>>> import jax.numpy as np
>>> def f(x):
...     return (x ** 2).sum(), x.sum()
>>> vf = jax.value_and_jacfwd(f, has_aux=True)
>>> (y, aux), df = vf(np.arange(3) * 1.)
>>> print(f'f(x) = {y}')
f(x) = 5.0
>>> print(f'df(x) = {df}')
df(x) = [0. 2. 4.]
>>> print(f'aux = {aux}')
aux = 3.0
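For readers curious how such a function can be built, here is a minimal sketch under stated assumptions; it is not necessarily how value_and_jacfwd.py is implemented. It follows the jvp-plus-vmap strategy that jax.jacfwd itself uses, but keeps the primal output instead of discarding it. The name value_and_jacfwd_sketch is hypothetical, and the sketch assumes the function takes a single 1-D floating-point array.

```python
import jax
import jax.numpy as jnp


def value_and_jacfwd_sketch(fun, has_aux=False):
    """Hypothetical sketch of forward-mode value-and-Jacobian.

    Assumes `fun` takes one 1-D floating-point array. If has_aux is True,
    `fun` must return a pair (value, aux) and only `value` is differentiated.
    """
    def wrapped(x):
        x = jnp.asarray(x)
        # One tangent per input coordinate: the rows of the identity matrix.
        basis = jnp.eye(x.shape[0], dtype=x.dtype)

        def pushfwd(t):
            # jax.jvp returns the primal outputs alongside the tangent outputs.
            primals_out, tangents_out = jax.jvp(fun, (x,), (t,))
            if has_aux:
                # Keep (value, aux) as the primal result; drop aux's tangent.
                return primals_out, tangents_out[0]
            return primals_out, tangents_out

        # Primal outputs do not depend on t, so they come back unbatched
        # (out_axes=None); tangents are stacked on the last axis, matching
        # the Jacobian layout of jax.jacfwd.
        return jax.vmap(pushfwd, out_axes=(None, -1))(basis)

    return wrapped


def f(x):
    return (x ** 2).sum(), x.sum()


def g(x):
    return (x ** 2).sum()


x = jnp.arange(3) * 1.0
(y, aux), df = value_and_jacfwd_sketch(f, has_aux=True)(x)
y2, dg = value_and_jacfwd_sketch(g)(x)
print(y, aux, df)
```

The return structure ((value, aux), jacobian) mirrors the has_aux=True example above; the sketch evaluates the primal once per basis vector inside a single vmapped pass rather than in a separate call.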
This patch addresses the following GitHub issue:
https://github.com/google/jax/pull/762