https://github.com/quantecon/jaxinterp
JAX compatible multilinear interpolation
https://github.com/quantecon/jaxinterp
Last synced: 12 months ago
JSON representation
JAX compatible multilinear interpolation
- Host: GitHub
- URL: https://github.com/quantecon/jaxinterp
- Owner: QuantEcon
- Created: 2023-06-13T00:55:42.000Z (about 3 years ago)
- Default Branch: main
- Last Pushed: 2023-06-13T23:40:36.000Z (almost 3 years ago)
- Last Synced: 2025-06-29T09:33:14.772Z (12 months ago)
- Size: 3.91 KB
- Stars: 8
- Watchers: 6
- Forks: 2
- Open Issues: 3
-
Metadata Files:
- Readme: README.md
Awesome Lists containing this project
README
# JAX compatible multilinear interpolation
We need fast linear interpolation in arbitrary dimensions, compatible with
`jax.jit`.
Suppose we have a function $f$ mapping $\mathbb R^n$ to $\mathbb R$ and we evalute $f$ on a
finite set of grid points $v_1, \ldots, v_k$, where each $v_i$ is a point in
$\mathbb R^n$.
Let $f_i = f(v_i)$.
We now want to be able to evaluate $\hat f(x)$, which is a linear interpolation
of the grid points $v_1, \ldots, v_k$ and corresponding function values
$f_1, \ldots, f_k$, at the point $x$.
We want to do this where $n$ can be any integer.
We also want to be able to do this in a vectorized, parallelized manner, so that
we can pass a large number of evaluation points $x_1, \ldots, x_m$, and compute
the points $\hat f(x_1), \ldots, \hat f(x_m)$ efficiently.
This is already possible in Numba, using the Econforge
[interpolation library](https://www.econforge.org/interpolation.py/):
```python
import numpy as np
from interpolation.splines import UCGrid, CGrid, nodes
from interpolation.splines import eval_linear
# function to interpolate
f = lambda x,y: np.sin(np.sqrt(x**2+y**2+0.00001))/np.sqrt(x**2+y**2+0.00001)
# uniform cartesian grid with a small number of points
grid = UCGrid((-1.0, 1.0, 10), (-1.0, 1.0, 10))
# get grid points
gp = nodes(grid) # 100x2 matrix
# compute values on grid points
values = f(gp[:,0], gp[:,1]).reshape((10,10))
# interpolate at one point
point = np.array([0.1,0.45]) # 1d array
val = eval_linear(grid, values, point) # float
# interpolate at many points:
points = np.random.random((10000,2))
eval_linear(grid, values, points) # 10000 vector
```
The question is, what is the most efficient way to replicate this in JAX?
Using `jax.scipy.ndimage.map_coordinates`, @junnanZ has implemented evaluation at a single point for $n=4$ in [this file](https://github.com/jstac/sdfs_via_autodiff/blob/main/code/ssy/continuous_junnan/utils.py).
We could then extend this from one `point` to `points` via `vmap`.
This discussion suggests that the vmap approach will be efficient: https://github.com/google/jax/issues/6312
Other alternatives include https://www.tensorflow.org/probability/api_docs/python/tfp/substrates/jax/math/batch_interp_regular_nd_grid
Are these our best options? If so, which is best?
Once we have decided, let's
1. test to make sure that the output agrees with the output from the Python code
above, and
1. provide a nice interface, like the one in the Python code above.