<h1 align='center'> diffraxtra </h1>
<h3 align="center"><code>diffrax</code> extras</h3>

<p align="center">
    <a href="https://pypi.org/project/diffraxtra/"> <img alt="PyPI: diffraxtra" src="https://img.shields.io/pypi/v/diffraxtra?style=flat" /> </a>
    <a href="https://pypi.org/project/diffraxtra/"> <img alt="PyPI versions: diffraxtra" src="https://img.shields.io/pypi/pyversions/diffraxtra" /> </a>
    <a href="https://pypi.org/project/diffraxtra/"> <img alt="diffraxtra license" src="https://img.shields.io/github/license/GalacticDynamics/diffraxtra" /> </a>
</p>
<p align="center">
    <a href="https://github.com/GalacticDynamics/diffraxtra/actions"> <img alt="CI status" src="https://github.com/GalacticDynamics/diffraxtra/workflows/CI/badge.svg" /> </a>
    <a href="https://codecov.io/gh/GalacticDynamics/diffraxtra"> <img alt="codecov" src="https://codecov.io/gh/GalacticDynamics/diffraxtra/graph/badge.svg" /> </a>
    <a href="https://scientific-python.org/specs/spec-0000/"> <img alt="SPEC 0" src="https://img.shields.io/badge/SPEC-0-green?labelColor=%23004811&color=%235CA038" /> </a>
    <a href="https://docs.astral.sh/ruff/"> <img alt="ruff" src="https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/charliermarsh/ruff/main/assets/badge/v2.json" /> </a>
    <a href="https://pre-commit.com"> <img alt="pre-commit" src="https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit" /> </a>
</p>

---

Extras for [diffrax][diffrax-link].

- `DiffEqSolver`: an object-oriented interface to `diffrax.diffeqsolve`.
- `VectorizedDenseInterpolation`: a vectorized form of
  `diffrax.DenseInterpolation` that works on batched results from
  `diffrax.diffeqsolve`.

For example,

<!-- invisible-code-block: python
import jax
jax.config.update("jax_enable_x64", True)
-->

```python
import jax.numpy as jnp
import diffrax as dfx
from diffraxtra import DiffEqSolver

# Construct a solver object.
solver = DiffEqSolver(dfx.Dopri5(),
                      stepsize_controller=dfx.PIDController(rtol=1e-5, atol=1e-5))

# And a differential equation to solve.
term = dfx.ODETerm(lambda t, y, args: -y)

# Then solve the differential equation.
saveat = dfx.SaveAt(t1=True, dense=True)
soln = solver(term, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat,
              vectorize_interpolation=True)

print(soln)
# Solution(
#   t0=f64[], t1=f64[], ts=f64[1],
#   ys=f64[1],
#   interpolation=VectorizedDenseInterpolation(
#     scalar_interpolation=DenseInterpolation( ... ),
#     batch_shape=(),
#     y0_shape=()
#   ),
#   ...
# )

soln.evaluate(jnp.array([0.1, 0.2, 0.3, 0.4]).reshape(2, 2))
# Array([[0.90483742, 0.81872516],
#        [0.74080871, 0.67031456]], dtype=float64)

```

## Installation

[![PyPI platforms][pypi-platforms]][pypi-link]
[![PyPI version][pypi-version]][pypi-link]

```bash
pip install diffraxtra
```

## Documentation

### `DiffEqSolver`

```pycon
>>> import jax.numpy as jnp
>>> import diffrax as dfx
>>> from diffraxtra import DiffEqSolver

```

Construct a solver object.

```pycon
>>> solver = DiffEqSolver(dfx.Dopri5(),
...                stepsize_controller=dfx.PIDController(rtol=1e-5, atol=1e-5))

```

And a differential equation to solve.

```pycon
>>> term = dfx.ODETerm(lambda t, y, args: -y)

```

Then solve the differential equation.

```pycon
>>> soln = solver(term, t0=0, t1=3, dt0=0.1, y0=1)
>>> soln
Solution( t0=f64[], t1=f64[], ts=f64[1],
          ys=f64[1], ... )

```

The solution can be saved at specific times.

```pycon
>>> saveat = dfx.SaveAt(ts=[0., 1., 2., 3.])
>>> soln = solver(term, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat)
>>> soln
Solution( t0=f64[], t1=f64[], ts=f64[4],
          ys=f64[4], ... )

```

The solution can be densely interpolated.

```pycon
>>> saveat = dfx.SaveAt(t1=True, dense=True)
>>> soln = solver(term, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat)
>>> soln
Solution( t0=f64[], t1=f64[], ts=f64[1],
          ys=f64[1], ... )
>>> soln.evaluate(0.5).round(3)
Array(0.607, dtype=float64)

```

The `VectorizedDenseInterpolation` class vectorizes the interpolation, so a
solution (batched or not) can be evaluated over arbitrarily shaped arrays of
times.

```pycon
>>> from diffraxtra import VectorizedDenseInterpolation
>>> soln = solver(term, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat)
>>> soln = VectorizedDenseInterpolation.apply_to_solution(soln)
>>> soln.evaluate(jnp.array([0.1, 0.2, 0.3, 0.4]).reshape(2, 2))
Array([[0.90483742, 0.81872516],
       [0.74080871, 0.67031456]], dtype=float64)

```

The same result is obtained more conveniently with the
`vectorize_interpolation` argument:

```pycon
>>> soln = solver(term, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat,
...               vectorize_interpolation=True)
>>> soln.evaluate(jnp.array([0.1, 0.2, 0.3, 0.4]).reshape(2, 2))
Array([[0.90483742, 0.81872516],
       [0.74080871, 0.67031456]], dtype=float64)

```

There are many ways to construct a `DiffEqSolver` object. For example, from an
existing `DiffEqSolver` object, in which case that same object is returned:

```pycon
>>> solver = DiffEqSolver(dfx.Dopri5())
>>> DiffEqSolver.from_(solver) is solver
True

```

From a `diffrax.AbstractSolver` object:

```pycon
>>> solver = DiffEqSolver.from_(dfx.Dopri5())
>>> solver
DiffEqSolver(solver=Dopri5())

```

(All other arguments take their default values; the repr shows only the
arguments that differ from their defaults.)

From a `collections.abc.Mapping`:

```pycon
>>> solver = DiffEqSolver.from_({"solver": dfx.Dopri5(),
...       "stepsize_controller": dfx.PIDController(rtol=1e-5, atol=1e-5)})
>>> solver
DiffEqSolver(
  solver=Dopri5(), stepsize_controller=PIDController(rtol=1e-05, atol=1e-05)
)

```

For a full enumeration of the ways to construct a `DiffEqSolver` object, see
`diffraxtra.DiffEqSolver.from_`.

### `VectorizedDenseInterpolation`

A vectorized wrapper around `diffrax.DenseInterpolation`. It works on both
batched and non-batched interpolations.

```pycon
>>> import jax
>>> import jax.numpy as jnp
>>> import diffrax as dfx
>>> from diffraxtra import VectorizedDenseInterpolation

```

We'll start with a non-batched interpolation:

```pycon
>>> vector_field = lambda t, y, args: -y
>>> term = dfx.ODETerm(vector_field)
>>> solver = dfx.Dopri5()
>>> ts = jnp.array([0.0, 1, 2, 3])
>>> saveat = dfx.SaveAt(ts=ts, dense=True)
>>> stepsize_controller = dfx.PIDController(rtol=1e-5, atol=1e-5)

>>> sol = dfx.diffeqsolve(
...     term, solver, t0=0, t1=3, dt0=0.1, y0=1, saveat=saveat,
...     stepsize_controller=stepsize_controller)
>>> interp = VectorizedDenseInterpolation(sol.interpolation)
>>> interp
VectorizedDenseInterpolation(
  scalar_interpolation=DenseInterpolation(
    ts=f64[1,4097],
    ts_size=weak_i64[1],
    infos={'k': f64[1,4096,7], 'y0': f64[1,4096], 'y1': f64[1,4096]},
    interpolation_cls=diffrax._solver.dopri5._Dopri5Interpolation,
    direction=weak_i64[1],
    t0_if_trivial=f64[1],
    y0_if_trivial=f64[1]
  ),
  batch_shape=(),
  y0_shape=()
)

```

It can be evaluated as usual:

```pycon
>>> interp.evaluate(ts[-1])  # scalar evaluation
Array(0.04978961, dtype=float64)

```

It also works on arrays of times, without needing to apply `jax.vmap` manually:

```pycon
>>> interp.evaluate(ts)  # It works on arrays!
Array([1. , 0.36788338, 0.13533922, 0.04978961], dtype=float64)

```

```pycon
>>> interp.evaluate(ts, ts[0])  # increment y(t1) - y(t0): array t0, scalar t1
Array([0. , 0.63211662, 0.86466078, 0.95021039], dtype=float64)

```

Better yet, the time array may be arbitrarily shaped:

```pycon
>>> interp.evaluate(ts.reshape(2, 2)).round(3)
Array([[1.   , 0.368],
       [0.135, 0.05 ]], dtype=float64)

```

As a convenience, `VectorizedDenseInterpolation.apply_to_solution` replaces the
interpolation of a solution with its vectorized form. Inside a jitted context
this is effectively "in-place"; otherwise it is out-of-place and returns a
copy:

```pycon
>>> sol = VectorizedDenseInterpolation.apply_to_solution(sol)
>>> isinstance(sol, dfx.Solution)
True
>>> isinstance(sol.interpolation, VectorizedDenseInterpolation)
True

```

Now we'll batch the solve with `jax.vmap`, which produces a batched
interpolation:

```pycon
>>> @jax.vmap
... def solve(y0):
...     sol = dfx.diffeqsolve(
...         term, solver, t0=0, t1=3, dt0=0.1, y0=y0, saveat=saveat,
...         stepsize_controller=stepsize_controller)
...     return sol
>>> sol = solve(jnp.array([1, 2, 3]))
>>> interp = VectorizedDenseInterpolation(sol.interpolation)

```

```pycon
>>> interp.evaluate(ts[-1]).round(3)  # scalar eval of batched interp
Array([0.05 , 0.1  , 0.149], dtype=float64)

```

```pycon
>>> interp.evaluate(ts).astype(jnp.float64).round(3)  # array eval of batched interp
Array([[1.   , 0.368, 0.135, 0.05 ],
       [2.   , 0.736, 0.271, 0.1  ],
       [3.   , 1.104, 0.406, 0.149]], dtype=float64)

```

```pycon
>>> interp.evaluate(ts, ts[0]).round(3)  # mixed scalar and array eval
Array([[0.   , 0.632, 0.865, 0.95 ],
       [0.   , 1.264, 1.729, 1.9  ],
       [0.   , 1.896, 2.594, 2.851]], dtype=float64)

```

```pycon
>>> ys = interp.evaluate(ts.reshape(2, 2)).round(3)  # arbitrary shape eval
>>> ys
Array([[[1.   , 0.368],
        [0.135, 0.05 ]],
       [[2.   , 0.736],
        [0.271, 0.1  ]],
       [[3.   , 1.104],
        [0.406, 0.149]]], dtype=float64)
>>> ys.shape  # (batch, *times)
(3, 2, 2)

```

## Citation

[![DOI][zenodo-badge]][zenodo-link]

If you found this library useful and would like to cite it, follow the DOI
link above.

## Development

[![Actions Status][actions-badge]][actions-link]
[![codecov][codecov-badge]][codecov-link]
[![SPEC 0 — Minimum Supported Dependencies][spec0-badge]][spec0-link]
[![pre-commit][pre-commit-badge]][pre-commit-link]
[![ruff][ruff-badge]][ruff-link]

We welcome contributions!

<!-- prettier-ignore-start -->
[actions-badge]:            https://github.com/GalacticDynamics/diffraxtra/workflows/CI/badge.svg
[actions-link]:             https://github.com/GalacticDynamics/diffraxtra/actions
[codecov-badge]:            https://codecov.io/gh/GalacticDynamics/diffraxtra/graph/badge.svg
[codecov-link]:             https://codecov.io/gh/GalacticDynamics/diffraxtra
[pre-commit-badge]:         https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit
[pre-commit-link]:          https://pre-commit.com
[pypi-link]:                https://pypi.org/project/diffraxtra/
[pypi-platforms]:           https://img.shields.io/pypi/pyversions/diffraxtra
[pypi-version]:             https://img.shields.io/pypi/v/diffraxtra
[ruff-badge]:               https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/charliermarsh/ruff/main/assets/badge/v2.json
[ruff-link]:                https://docs.astral.sh/ruff/
[spec0-badge]:              https://img.shields.io/badge/SPEC-0-green?labelColor=%23004811&color=%235CA038
[spec0-link]:               https://scientific-python.org/specs/spec-0000/
[zenodo-badge]:             https://zenodo.org/badge/DOI/10.5281/zenodo.14806581.svg
[zenodo-link]:              https://zenodo.org/doi/10.5281/zenodo.14806581


[diffrax-link]: https://docs.kidger.site/diffrax/

<!-- prettier-ignore-end -->