https://github.com/galacticdynamics/diffraxtra
Extras for Diffrax: OOP and vectorization
deep-learning differential-equations diffrax dynamical-systems equinox jax machine-learning neural-differential-equations neural-networks
- Host: GitHub
- URL: https://github.com/galacticdynamics/diffraxtra
- Owner: GalacticDynamics
- License: mit
- Created: 2025-02-04T19:16:39.000Z (8 months ago)
- Default Branch: main
- Last Pushed: 2025-02-04T22:20:40.000Z (8 months ago)
- Last Synced: 2025-02-04T22:25:01.630Z (8 months ago)
- Topics: deep-learning, differential-equations, diffrax, dynamical-systems, equinox, jax, machine-learning, neural-differential-equations, neural-networks
- Language: Python
- Homepage:
- Size: 75.2 KB
- Stars: 1
- Watchers: 2
- Forks: 1
- Open Issues: 0
Metadata Files:
- Readme: README.md
- Contributing: .github/CONTRIBUTING.md
- License: LICENSE
# diffraxtra

`diffrax` extras

---
Extras for [diffrax][diffrax-link].

- `DiffEqSolver`: an object-oriented interface to `diffrax.diffeqsolve`.
- `VectorizedDenseInterpolation`: a vectorized form of
  `diffrax.DenseInterpolation` that works on batched results from
  `diffrax.diffeqsolve`.

For example,
```python
import jax

# The outputs shown below assume 64-bit floats, which JAX disables by default.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import diffrax as dfx
from diffraxtra import DiffEqSolver

# Construct a solver object.
solver = DiffEqSolver(dfx.Dopri5(),
                      stepsize_controller=dfx.PIDController(rtol=1e-5, atol=1e-5))

# And a differential equation to solve.
term = dfx.ODETerm(lambda t, y, args: -y)

# Then solve the differential equation.
saveat = dfx.SaveAt(t1=True, dense=True)
soln = solver(term, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat,
              vectorize_interpolation=True)

print(soln)
# Solution(
#   t0=f64[], t1=f64[], ts=f64[1],
#   ys=f64[1],
#   interpolation=VectorizedDenseInterpolation(
#     scalar_interpolation=DenseInterpolation( ... ),
#     batch_shape=(),
#     y0_shape=()
#   ),
#   ...
# )

soln.evaluate(jnp.array([0.1, 0.2, 0.3, 0.4]).reshape(2, 2))
# Array([[0.90483742, 0.81872516],
#        [0.74080871, 0.67031456]], dtype=float64)
```
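An object-oriented solver interface bundles the solver settings into a single immutable object that is then called like a function. The idea can be sketched in plain Python without Diffrax. The snippet below is a hypothetical illustration, not diffraxtra's implementation: `ToySolver` and `euler_solve` are made-up names, and a fixed-step explicit Euler loop stands in for `diffrax.diffeqsolve`.

```python
import math
from dataclasses import dataclass, replace
from typing import Callable


def euler_solve(f: Callable[[float, float], float], t0: float, t1: float,
                dt0: float, y0: float, *, max_steps: int = 4096) -> float:
    """Hypothetical stand-in for diffeqsolve: fixed-step explicit Euler."""
    t, y, steps = t0, y0, 0
    while t < t1:
        if steps == max_steps:
            raise RuntimeError("max_steps reached before t1")
        h = min(dt0, t1 - t)  # never step past t1
        y = y + h * f(t, y)
        t, steps = t + h, steps + 1
    return y


@dataclass(frozen=True)
class ToySolver:
    """Hypothetical OOP wrapper: settings are stored once, then reused per call."""
    max_steps: int = 4096

    def __call__(self, f, t0: float, t1: float, dt0: float, y0: float) -> float:
        # Forward the stored settings together with the per-call arguments.
        return euler_solve(f, t0, t1, dt0, y0, max_steps=self.max_steps)


solver = ToySolver()
y1 = solver(lambda t, y: -y, t0=0.0, t1=1.0, dt0=0.01, y0=1.0)
print(round(y1, 3), round(math.exp(-1), 3))  # 0.366 0.368

# A variant is a new object; `solver` itself is left unchanged.
strict = replace(solver, max_steps=10)  # too few steps to reach t1 with dt0=0.01
```

Because the configuration object is frozen, variants are derived with `dataclasses.replace` rather than by mutation, which is in the same spirit as how immutable JAX-compatible objects are usually handled.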
## Installation
[![PyPI platforms][pypi-platforms]][pypi-link]
[![PyPI version][pypi-version]][pypi-link]

```bash
pip install diffraxtra
```

## Documentation
### `DiffEqSolver`
```pycon
>>> import jax
>>> jax.config.update("jax_enable_x64", True)  # outputs below are float64
>>> import jax.numpy as jnp
>>> import diffrax as dfx
>>> from diffraxtra import DiffEqSolver
```
Construct a solver object.
```pycon
>>> solver = DiffEqSolver(dfx.Dopri5(),
... stepsize_controller=dfx.PIDController(rtol=1e-5, atol=1e-5))
```
And a differential equation to solve.
```pycon
>>> term = dfx.ODETerm(lambda t, y, args: -y)
```
Then solve the differential equation.
```pycon
>>> soln = solver(term, t0=0, t1=3, dt0=0.1, y0=1)
>>> soln
Solution( t0=f64[], t1=f64[], ts=f64[1],
ys=f64[1], ... )
```
The solution can be saved at specific times.
```pycon
>>> saveat = dfx.SaveAt(ts=[0., 1., 2., 3.])
>>> soln = solver(term, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat)
>>> soln
Solution( t0=f64[], t1=f64[], ts=f64[4],
ys=f64[4], ... )
```
The solution can be densely interpolated.
```pycon
>>> saveat = dfx.SaveAt(t1=True, dense=True)
>>> soln = solver(term, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat)
>>> soln
Solution( t0=f64[], t1=f64[], ts=f64[1],
ys=f64[1], ... )
>>> soln.evaluate(0.5).round(3)
Array(0.607, dtype=float64)
```
Using the `VectorizedDenseInterpolation` class, the interpolation can be
vectorized, enabling evaluation of batched solutions over batches of times.

```pycon
>>> from diffraxtra import VectorizedDenseInterpolation
>>> soln = solver(term, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat)
>>> soln = VectorizedDenseInterpolation.apply_to_solution(soln)
>>> soln.evaluate(jnp.array([0.1, 0.2, 0.3, 0.4]).reshape(2, 2))
Array([[0.90483742, 0.81872516],
[0.74080871, 0.67031456]], dtype=float64)
```
This can be more conveniently done using the `vectorize_interpolation` argument.
```pycon
>>> soln = solver(term, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat,
... vectorize_interpolation=True)
>>> soln.evaluate(jnp.array([0.1, 0.2, 0.3, 0.4]).reshape(2, 2))
Array([[0.90483742, 0.81872516],
[0.74080871, 0.67031456]], dtype=float64)
```
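These outputs can be checked by hand: dy/dt = -y with y(0) = 1 has the exact solution y(t) = e^(-t). The standard-library snippet below evaluates that closed form (it does not call Diffrax) and compares it with the values printed above. The discrepancies come out on the order of 1e-5, in line with the `rtol`/`atol` given to the `PIDController`.

```python
import math


def exact(t: float) -> float:
    """Closed-form solution of dy/dt = -y with y(0) = 1."""
    return math.exp(-t)


print(round(exact(0.5), 3))  # 0.607

# The 2x2 grid of times and the solver output printed above.
grid = [[0.1, 0.2], [0.3, 0.4]]
solver_out = [[0.90483742, 0.81872516], [0.74080871, 0.67031456]]

max_err = max(abs(exact(t) - y)
              for row_t, row_y in zip(grid, solver_out)
              for t, y in zip(row_t, row_y))
print(f"{max_err:.1e}")  # 9.5e-06
```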
There are many ways to construct a `DiffEqSolver` object. For example, we can
make a new one from an existing `DiffEqSolver` object, which is returned as-is:

```pycon
>>> solver = DiffEqSolver(dfx.Dopri5())
>>> DiffEqSolver.from_(solver) is solver
True
```
From a `diffrax.AbstractSolver` object.
```pycon
>>> solver = DiffEqSolver.from_(dfx.Dopri5())
>>> solver
DiffEqSolver(
solver=Dopri5(scan_kind=None),
stepsize_controller=ConstantStepSize(),
adjoint=RecursiveCheckpointAdjoint(checkpoints=None),
event=None,
max_steps=4096
)
```
From a `collections.abc.Mapping`.
```pycon
>>> solver = DiffEqSolver.from_({"solver": dfx.Dopri5(),
... "stepsize_controller": dfx.PIDController(rtol=1e-5, atol=1e-5)})
>>> solver
DiffEqSolver(
solver=Dopri5(scan_kind=None),
stepsize_controller=PIDController( ... ),
adjoint=RecursiveCheckpointAdjoint(checkpoints=None),
event=None,
max_steps=4096
)
```
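`from_` picks a conversion based on the type of its argument. One way to organize such a multi-type constructor in plain Python is `functools.singledispatch`. The sketch below is hypothetical throughout: `Config` and `config_from` are made-up names, a string stands in for a `diffrax.AbstractSolver`, and it makes no claim about how `diffraxtra` implements `from_` internally.

```python
from collections.abc import Mapping
from dataclasses import dataclass
from functools import singledispatch
from typing import Any


@dataclass(frozen=True)
class Config:
    """Hypothetical stand-in for a solver-configuration object."""
    method: str
    max_steps: int = 4096


@singledispatch
def config_from(obj: Any) -> Config:
    # Fallback for unsupported input types.
    raise TypeError(f"cannot build a Config from {type(obj).__name__}")


@config_from.register
def _(obj: Config) -> Config:
    return obj  # already a Config: return the same object


@config_from.register
def _(obj: str) -> Config:
    return Config(method=obj)  # a bare "solver": use default settings


@config_from.register
def _(obj: Mapping) -> Config:
    return Config(**obj)  # mapping keys become constructor arguments


c = Config("dopri5")
print(config_from(c) is c)  # True
print(config_from("euler"))  # Config(method='euler', max_steps=4096)
print(config_from({"method": "euler", "max_steps": 10}))  # Config(method='euler', max_steps=10)
```

Dispatching on `collections.abc.Mapping` rather than `dict` lets any mapping type through, since `singledispatch` resolves abstract base classes.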
For a full enumeration of the ways to construct a `DiffEqSolver` object, see
`diffraxtra.DiffEqSolver.from_`.

### `VectorizedDenseInterpolation`
A vectorized wrapper around a `diffrax.DenseInterpolation`. It works on both
batched and non-batched interpolations.
```pycon
>>> import jax
>>> import jax.numpy as jnp
>>> import diffrax as dfx
>>> from diffraxtra import VectorizedDenseInterpolation
```
We'll start with a non-batched interpolation:
```pycon
>>> vector_field = lambda t, y, args: -y
>>> term = dfx.ODETerm(vector_field)
>>> solver = dfx.Dopri5()
>>> ts = jnp.array([0.0, 1, 2, 3])
>>> saveat = dfx.SaveAt(ts=ts, dense=True)
>>> stepsize_controller = dfx.PIDController(rtol=1e-5, atol=1e-5)
>>> sol = dfx.diffeqsolve(
... term, solver, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat,
... stepsize_controller=stepsize_controller)
>>> interp = VectorizedDenseInterpolation(sol.interpolation)
>>> interp
VectorizedDenseInterpolation(
scalar_interpolation=DenseInterpolation(
ts=f64[1,4097],
ts_size=weak_i64[1],
infos={'k': f64[1,4096,7], 'y0': f64[1,4096], 'y1': f64[1,4096]},
interpolation_cls=<class '...'>,
direction=weak_i64[1],
t0_if_trivial=f64[1],
y0_if_trivial=f64[1]
),
batch_shape=(),
y0_shape=()
)
```
This can be evaluated in the usual way:
```pycon
>>> interp.evaluate(ts[-1]) # scalar evaluation
Array(0.04978961, dtype=float64)
```
It also works on arrays, without needing to apply `jax.vmap` manually:
```pycon
>>> interp.evaluate(ts) # It works on arrays!
Array([1. , 0.36788338, 0.13533922, 0.04978961], dtype=float64)
```
```pycon
>>> interp.evaluate(ts, ts[0])  # y(t1) - y(t0), with array t0 and scalar t1
Array([0. , 0.63211662, 0.86466078, 0.95021039], dtype=float64)
```
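With two arguments, `evaluate(t0, t1)` follows Diffrax's convention of returning the increment $y(t_1) - y(t_0)$ rather than a single value. In the call above `t0` is the array `ts` and `t1` is the scalar `ts[0] = 0`, so for the exact solution $y(t) = e^{-t}$ each entry should be

$$
y(0) - y(t_i) = 1 - e^{-t_i}, \qquad t_i \in \{0, 1, 2, 3\},
$$

that is $0$, $1 - e^{-1} \approx 0.632$, $1 - e^{-2} \approx 0.865$ and $1 - e^{-3} \approx 0.950$, which matches the output above to within the solver tolerance.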
Better yet, the time array may be arbitrarily shaped:
```pycon
>>> interp.evaluate(ts.reshape(2, 2)).round(3)
Array([[1. , 0.368],
[0.135, 0.05 ]], dtype=float64)
```
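The shape handling can be pictured as "flatten the times, evaluate each one, restore the shape". The NumPy sketch below shows only that bookkeeping: `evaluate_any_shape` is a hypothetical helper, `math.exp(-t)` stands in for a scalar interpolation, and the Python loop is there for clarity. It is not how `VectorizedDenseInterpolation` is implemented.

```python
import math

import numpy as np


def evaluate_any_shape(scalar_eval, ts):
    """Hypothetical helper: apply a scalar-time function to any array of times."""
    ts = np.asarray(ts, dtype=float)
    flat = np.array([scalar_eval(t) for t in ts.ravel()])  # one call per time
    return flat.reshape(ts.shape)  # restore the caller's shape


ts = np.array([0.0, 1.0, 2.0, 3.0])
ys = evaluate_any_shape(lambda t: math.exp(-t), ts.reshape(2, 2))
print(ys.shape)  # (2, 2)
print(ys.round(3))
```

In JAX the loop would normally be replaced by `jax.vmap` or `jnp.vectorize`, so that the whole evaluation can be traced and compiled.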
As a convenience, we can also apply the `VectorizedDenseInterpolation` to the
solution to modify the interpolation "in-place" (when in a jitted context,
otherwise out-of-place, returning a copy):

```pycon
>>> sol = VectorizedDenseInterpolation.apply_to_solution(sol)
>>> isinstance(sol, dfx.Solution)
True
>>> isinstance(sol.interpolation, VectorizedDenseInterpolation)
True
```
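Because these objects are immutable, applying the wrapper to a solution amounts to building a new solution that differs in a single field. Under `jax.jit`, XLA can often avoid the actual copy. The standard-library sketch below shows the out-of-place pattern with `dataclasses.replace`. `Solution`, `Vectorized` and `with_vectorized_interpolation` are hypothetical stand-ins rather than Diffrax or diffraxtra classes. Equinox-based code would usually express the same update with `equinox.tree_at`.

```python
from dataclasses import dataclass, replace
from typing import Any


@dataclass(frozen=True)
class Solution:
    """Hypothetical stand-in for an immutable solution container."""
    ys: tuple
    interpolation: Any


@dataclass(frozen=True)
class Vectorized:
    """Hypothetical stand-in for a wrapper around a scalar interpolation."""
    scalar_interpolation: Any


def with_vectorized_interpolation(sol: Solution) -> Solution:
    # Out-of-place update: a new Solution sharing every field but one.
    return replace(sol, interpolation=Vectorized(sol.interpolation))


sol = Solution(ys=(1.0,), interpolation="dense")
new = with_vectorized_interpolation(sol)
print(new.interpolation)  # Vectorized(scalar_interpolation='dense')
print(sol.interpolation)  # dense  (the original is unchanged)
```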
Now we'll batch the interpolation:
```pycon
>>> @jax.vmap
... def solve(y0):
... sol = dfx.diffeqsolve(
... term, solver, t0=0, t1=3, dt0=0.1, y0=y0, saveat=saveat,
... stepsize_controller=stepsize_controller)
... return sol
>>> sol = solve(jnp.array([1, 2, 3]))
>>> interp = VectorizedDenseInterpolation(sol.interpolation)
```
```pycon
>>> interp.evaluate(ts[-1]).round(3) # scalar eval of batched interp
Array([0.05 , 0.1 , 0.149], dtype=float64)
```
```pycon
>>> interp.evaluate(ts).astype(jnp.float64).round(3) # array eval of batched interp
Array([[1. , 0.368, 0.135, 0.05 ],
[2. , 0.736, 0.271, 0.1 ],
[3. , 1.104, 0.406, 0.149]], dtype=float64)
```
```pycon
>>> interp.evaluate(ts, ts[0]).round(3) # mixed scalar and array eval
Array([[0. , 0.632, 0.865, 0.95 ],
[0. , 1.264, 1.729, 1.9 ],
[0. , 1.896, 2.594, 2.851]], dtype=float64)
```
```pycon
>>> ys = interp.evaluate(ts.reshape(2, 2)).round(3) # arbitrary shape eval
>>> ys
Array([[[1. , 0.368],
[0.135, 0.05 ]],
[[2. , 0.736],
[0.271, 0.1 ]],
[[3. , 1.104],
[0.406, 0.149]]], dtype=float64)
>>> ys.shape # (batch, *times)
(3, 2, 2)
```
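The batched values also follow from the closed form. Each trajectory of dy/dt = -y is y(t) = y0·e^(-t), so initial conditions y0 ∈ {1, 2, 3} evaluated on the 2×2 grid `ts.reshape(2, 2)` give a (3, 2, 2) result laid out as (batch, *times). The pure-Python snippet below recomputes those numbers to three decimals. It checks the printed output and does not use the library.

```python
import math

y0s = [1.0, 2.0, 3.0]            # batch of initial conditions
ts = [[0.0, 1.0], [2.0, 3.0]]    # times, shaped (2, 2)

# Exact solution y(t) = y0 * exp(-t), laid out as (batch, *times).
ys = [[[round(y0 * math.exp(-t), 3) for t in row] for row in ts] for y0 in y0s]

shape = (len(ys), len(ys[0]), len(ys[0][0]))
print(shape)   # (3, 2, 2)
print(ys[1])   # [[2.0, 0.736], [0.271, 0.1]]
```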
## Citation
[![DOI][zenodo-badge]][zenodo-link]

If you enjoyed using this library and would like to cite the software you use,
click the DOI link above.

## Development
[![Actions Status][actions-badge]][actions-link]
[![codecov][codecov-badge]][codecov-link]
[![SPEC 0 — Minimum Supported Dependencies][spec0-badge]][spec0-link]
[![pre-commit][pre-commit-badge]][pre-commit-link]
[![ruff][ruff-badge]][ruff-link]

We welcome contributions!

[actions-badge]: https://github.com/GalacticDynamics/diffraxtra/workflows/CI/badge.svg
[actions-link]: https://github.com/GalacticDynamics/diffraxtra/actions
[codecov-badge]: https://codecov.io/gh/GalacticDynamics/diffraxtra/graph/badge.svg
[codecov-link]: https://codecov.io/gh/GalacticDynamics/diffraxtra
[pre-commit-badge]: https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit
[pre-commit-link]: https://pre-commit.com
[pypi-link]: https://pypi.org/project/diffraxtra/
[pypi-platforms]: https://img.shields.io/pypi/pyversions/diffraxtra
[pypi-version]: https://img.shields.io/pypi/v/diffraxtra
[ruff-badge]: https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/charliermarsh/ruff/main/assets/badge/v2.json
[ruff-link]: https://docs.astral.sh/ruff/
[spec0-badge]: https://img.shields.io/badge/SPEC-0-green?labelColor=%23004811&color=%235CA038
[spec0-link]: https://scientific-python.org/specs/spec-0000/
[zenodo-badge]: https://zenodo.org/badge/DOI/10.5281/zenodo.14806581.svg
[zenodo-link]: https://zenodo.org/doi/10.5281/zenodo.14806581
[diffrax-link]: https://docs.kidger.site/diffrax/