Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/patrick-kidger/diffrax?ref=zzun.app
Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. https://docs.kidger.site/diffrax/
- Host: GitHub
- URL: https://github.com/patrick-kidger/diffrax?ref=zzun.app
- Owner: patrick-kidger
- License: apache-2.0
- Created: 2021-07-25T23:42:12.000Z (over 3 years ago)
- Default Branch: main
- Last Pushed: 2024-08-02T01:30:23.000Z (3 months ago)
- Last Synced: 2024-08-02T06:11:47.067Z (3 months ago)
- Topics: deep-learning, differential-equations, dynamical-systems, equinox, jax, machine-learning, neural-differential-equations, neural-networks
- Language: Python
- Homepage:
- Size: 6.08 MB
- Stars: 1,337
- Watchers: 20
- Forks: 123
- Open Issues: 149
- Metadata Files:
- Readme: README.md
- Contributing: CONTRIBUTING.md
- Funding: .github/FUNDING.yml
- License: LICENSE
Awesome Lists containing this project
- awesome-sciml - patrick-kidger/diffrax at zzun.app
README
Diffrax
Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable.
Diffrax is a [JAX](https://github.com/google/jax)-based library providing numerical differential equation solvers.
Features include:
- ODE/SDE/CDE (ordinary/stochastic/controlled) solvers;
- lots of different solvers (including `Tsit5`, `Dopri8`, symplectic solvers, implicit solvers);
- vmappable _everything_ (including the region of integration; see the sketch after this list);
- using a PyTree as the state;
- dense solutions;
- multiple adjoint methods for backpropagation;
- support for neural differential equations.

_From a technical point of view, the internal structure of the library is pretty cool -- all kinds of equations (ODEs, SDEs, CDEs) are solved in a unified way (rather than being treated separately), producing a small tightly-written library._
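To illustrate the "vmappable everything" point above, here is a minimal sketch (not part of the upstream README) that batches a solve over several initial conditions with `jax.vmap`; it reuses the same vector field as the quick example further down.

```python
# Minimal sketch: batching an ODE solve over initial conditions with jax.vmap.
import jax
import jax.numpy as jnp
from diffrax import diffeqsolve, ODETerm, Dopri5

def f(t, y, args):
    return -y  # the linear ODE dy/dt = -y

term = ODETerm(f)
solver = Dopri5()

def solve(y0):
    # Each call solves from t0=0 to t1=1 for one initial condition.
    return diffeqsolve(term, solver, t0=0, t1=1, dt0=0.1, y0=y0).ys

# Three independent initial conditions, solved in one vmapped (and jit-able) call.
y0_batch = jnp.array([[1.0], [2.0], [3.0]])
batched_ys = jax.vmap(solve)(y0_batch)
```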
## Installation
```
pip install diffrax
```
Requires Python 3.9+, JAX 0.4.13+, and [Equinox](https://github.com/patrick-kidger/equinox) 0.10.11+.
## Documentation
Available at [https://docs.kidger.site/diffrax](https://docs.kidger.site/diffrax).
## Quick example
```python
from diffrax import diffeqsolve, ODETerm, Dopri5
import jax.numpy as jnp

def f(t, y, args):
    return -y

term = ODETerm(f)
solver = Dopri5()
y0 = jnp.array([2., 3.])
solution = diffeqsolve(term, solver, t0=0, t1=1, dt0=0.1, y0=y0)
```
Here, `Dopri5` refers to the Dormand--Prince 5(4) numerical differential equation solver, which is a standard choice for many problems.
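As a follow-up to the quick example, the sketch below (not from the upstream README; check the documentation for the full set of options) records the solution at several times with `SaveAt` and differentiates through the solve with `jax.grad`, illustrating the autodifferentiable and dense-output points above.

```python
# Sketch: save the solution at several times and differentiate through the solve.
import jax
import jax.numpy as jnp
from diffrax import diffeqsolve, ODETerm, Dopri5, SaveAt

def f(t, y, args):
    return -y

term = ODETerm(f)
solver = Dopri5()

def loss(y0):
    saveat = SaveAt(ts=jnp.linspace(0.0, 1.0, 11))  # times at which to record y
    sol = diffeqsolve(term, solver, t0=0, t1=1, dt0=0.1, y0=y0, saveat=saveat)
    return jnp.sum(sol.ys ** 2)  # sol.ts holds the save times, sol.ys the values

grads = jax.grad(loss)(jnp.array([2.0, 3.0]))  # gradient w.r.t. the initial condition
```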
## Citation
If you found this library useful in academic research, please cite: [(arXiv link)](https://arxiv.org/abs/2202.02435)
```bibtex
@phdthesis{kidger2021on,
    title={{O}n {N}eural {D}ifferential {E}quations},
    author={Patrick Kidger},
    year={2021},
    school={University of Oxford},
}
```
(Also consider starring the project on GitHub.)
## See also: other libraries in the JAX ecosystem
**Always useful**
[Equinox](https://github.com/patrick-kidger/equinox): neural networks and everything not already in core JAX!
[jaxtyping](https://github.com/patrick-kidger/jaxtyping): type annotations for shape/dtype of arrays.

**Deep learning**
[Optax](https://github.com/deepmind/optax): first-order gradient (SGD, Adam, ...) optimisers.
[Orbax](https://github.com/google/orbax): checkpointing (async/multi-host/multi-device).
[Levanter](https://github.com/stanford-crfm/levanter): scalable+reliable training of foundation models (e.g. LLMs).

**Scientific computing**
[Optimistix](https://github.com/patrick-kidger/optimistix): root finding, minimisation, fixed points, and least squares.
[Lineax](https://github.com/patrick-kidger/lineax): linear solvers.
[BlackJAX](https://github.com/blackjax-devs/blackjax): probabilistic+Bayesian sampling.
[sympy2jax](https://github.com/patrick-kidger/sympy2jax): SymPy<->JAX conversion; train symbolic expressions via gradient descent.
[PySR](https://github.com/milesCranmer/PySR): symbolic regression. (Non-JAX honourable mention!)

**Awesome JAX**
[Awesome JAX](https://github.com/n2cholas/awesome-jax): a longer list of other JAX projects.