{"id":15362497,"url":"https://github.com/patrick-kidger/quax","last_synced_at":"2025-11-17T15:31:25.518Z","repository":{"id":195430825,"uuid":"655969803","full_name":"patrick-kidger/quax","owner":"patrick-kidger","description":"Multiple dispatch over abstract array types in JAX.","archived":false,"fork":false,"pushed_at":"2025-04-11T12:49:59.000Z","size":103,"stargazers_count":115,"open_issues_count":7,"forks_count":5,"subscribers_count":7,"default_branch":"main","last_synced_at":"2025-04-11T13:15:06.246Z","etag":null,"topics":["equinox","jax","multiple-dispatch","python-typing","typing"],"latest_commit_sha":null,"homepage":"","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"apache-2.0","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/patrick-kidger.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":"CONTRIBUTING.md","funding":".github/FUNDING.yml","license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null},"funding":{"github":["patrick-kidger"]}},"created_at":"2023-06-20T02:19:45.000Z","updated_at":"2025-04-11T12:50:02.000Z","dependencies_parsed_at":"2023-11-11T00:22:43.058Z","dependency_job_id":"f44a678d-10fd-4ec6-a21f-eabc27020337","html_url":"https://github.com/patrick-kidger/quax","commit_stats":{"total_commits":45,"total_committers":4,"mean_commits":11.25,"dds":0.2222222222222222,"last_synced_commit":"166266f874c0706422397764f9f66ae36aeaa01e"},"previous_names":["patrick-kidger/quax"],"tags_count":5,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/patrick-kidger%2Fquax","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/patrick-kidger%2Fquax/tags","releases_url":"http
s://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/patrick-kidger%2Fquax/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/patrick-kidger%2Fquax/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/patrick-kidger","download_url":"https://codeload.github.com/patrick-kidger/quax/tar.gz/refs/heads/main","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":248631687,"owners_count":21136556,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["equinox","jax","multiple-dispatch","python-typing","typing"],"created_at":"2024-10-01T13:01:58.244Z","updated_at":"2025-11-17T15:31:20.491Z","avatar_url":"https://github.com/patrick-kidger.png","language":"Python","readme":"\u003ch1 align=\"center\"\u003eQuax\u003c/h1\u003e\n\u003ch2 align=\"center\"\u003eJAX + multiple dispatch + custom array-ish objects\u003c/h2\u003e\n\nFor example, this can mean overloading matrix multiplication to exploit sparsity or structure, or automatically rewriting a LoRA's matmul `(W + AB)v` into the more-efficient `Wv + ABv`.\n\nApplications include:\n\n- LoRA weight matrices\n- symbolic zeros\n- arrays with named dimensions\n- structured (e.g. tridiagonal) matrices\n- sparse arrays\n- quantised arrays\n- arrays with physical units attached\n- etc! (See the built-in `quax.examples` library for most of the above!)\n\nThis works via a custom JAX transform. Take an existing JAX program, wrap it in a `quax.quaxify`, and then pass in the custom array-ish objects. 
This means it will work even with existing programs that were not written to accept such array-ish objects!\n\n_(Just like how `jax.vmap` takes a program, but reinterprets each operation as its batched version, so too will `quax.quaxify` take a program and reinterpret each operation according to what array-ish types are passed.)_\n\n## Installation\n\n```\npip install quax\n```\n\n## Documentation\n\nAvailable at https://docs.kidger.site/quax.\n\n## Example: LoRA\n\nThis example demonstrates everything you need to use the built-in `quax.examples.lora` library.\n\n```python\nimport equinox as eqx\nimport jax.random as jr\nimport quax\nimport quax.examples.lora as lora\n\n#\n# Start off with any JAX program: here, the forward pass through a linear layer.\n#\n\nkey1, key2, key3 = jr.split(jr.PRNGKey(0), 3)\nlinear = eqx.nn.Linear(10, 12, key=key1)\nvector = jr.normal(key2, (10,))\n\ndef run(model, x):\n  return model(x)\n\nrun(linear, vector)  # can call this as normal\n\n#\n# Now let's Lora-ify it.\n#\n\n# Step 1: make the weight be a LoraArray.\nlora_weight = lora.LoraArray(linear.weight, rank=2, key=key3)\nlora_linear = eqx.tree_at(lambda l: l.weight, linear, lora_weight)\n# Step 2: quaxify and call the original function. The transform will call the\n# original function, whilst looking up any multiple dispatch rules registered.\n# (In this case for doing matmuls against LoraArrays.)\nquax.quaxify(run)(lora_linear, vector)\n# Appendix: Quax includes a helper to automatically apply Step 1 to all\n# `eqx.nn.Linear` layers in a model.\nlora_linear = lora.loraify(linear, rank=2, key=key3)\n```\n\n## Work in progress!\n\nRight now, the following are not supported:\n\n- `jax.lax.scan_p`\n- `jax.custom_vjp`\n\nIt should be fairly straightforward to add support for these; open an issue or pull request. (We've already got `jax.custom_jvp`, `jax.lax.cond_p`, and `jax.lax.while_p`. 
:) )\n\n## See also: other libraries in the JAX ecosystem\n\n**Always useful**  \n[Equinox](https://github.com/patrick-kidger/equinox): neural networks and everything not already in core JAX!  \n[jaxtyping](https://github.com/patrick-kidger/jaxtyping): type annotations for shape/dtype of arrays.  \n\n**Deep learning**  \n[Optax](https://github.com/deepmind/optax): first-order gradient (SGD, Adam, ...) optimisers.  \n[Orbax](https://github.com/google/orbax): checkpointing (async/multi-host/multi-device).  \n[Levanter](https://github.com/stanford-crfm/levanter): scalable+reliable training of foundation models (e.g. LLMs).  \n\n**Scientific computing**  \n[Diffrax](https://github.com/patrick-kidger/diffrax): numerical differential equation solvers.  \n[Optimistix](https://github.com/patrick-kidger/optimistix): root finding, minimisation, fixed points, and least squares.  \n[Lineax](https://github.com/patrick-kidger/lineax): linear solvers.  \n[BlackJAX](https://github.com/blackjax-devs/blackjax): probabilistic+Bayesian sampling.  \n[sympy2jax](https://github.com/patrick-kidger/sympy2jax): SymPy\u003c-\u003eJAX conversion; train symbolic expressions via gradient descent.  \n[PySR](https://github.com/milesCranmer/PySR): symbolic regression. (Non-JAX honourable mention!)  \n\n**Built on Quax**  \n[Quaxed](https://github.com/GalacticDynamics/quaxed): a namespace of already-wrapped `quaxify(jnp.foo)` operations.  \n[unxt](https://github.com/GalacticDynamics/unxt): Unitful Quantities.\n\n**Awesome JAX**  \n[Awesome JAX](https://github.com/n2cholas/awesome-jax): a longer list of other JAX projects.  
\n\n## Acknowledgements\n\nSignificantly inspired by https://github.com/davisyoshida/qax, https://github.com/stanford-crfm/levanter, and `jax.experimental.sparse`.\n","funding_links":["https://github.com/sponsors/patrick-kidger"],"categories":["Libraries"],"sub_categories":[],"project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fpatrick-kidger%2Fquax","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fpatrick-kidger%2Fquax","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fpatrick-kidger%2Fquax/lists"}