https://github.com/lockwo/awesome-jax
Curated list of JAX Resources and Packages
List: awesome-jax
autograd awesome awesome-list jax machine-learning numpy scientific-computing
- Host: GitHub
- URL: https://github.com/lockwo/awesome-jax
- Owner: lockwo
- License: apache-2.0
- Created: 2025-02-09T05:43:56.000Z
- Default Branch: main
- Last Pushed: 2025-02-11T06:24:31.000Z
- Last Synced: 2025-02-11T07:22:19.697Z
- Topics: autograd, awesome, awesome-list, jax, machine-learning, numpy, scientific-computing
- Homepage: https://lockwo.github.io/awesome-jax/
- Size: 21.5 KB
- Stars: 2
- Watchers: 1
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
- Contributing: CONTRIBUTING.md
- License: LICENSE
# Awesome JAX [![Awesome](https://awesome.re/badge.svg)](https://awesome.re)
[JAX](https://github.com/google/jax) brings automatic differentiation and the [XLA compiler](https://github.com/openxla/xla) together through a [NumPy](https://numpy.org/)-like API for high-performance machine learning research on accelerators like GPUs and TPUs.
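To make that pitch concrete, here is a minimal sketch using only core JAX (`jax.numpy`, `jax.grad`, `jax.jit`, `jax.vmap`). The toy least-squares model and the data are made up for illustration and are not taken from any library in this list; `jax.jit` compiles for whichever backend JAX detects, which is the CPU when no GPU or TPU is available.

```python
import jax
import jax.numpy as jnp

# Toy least-squares loss for a linear model (illustrative only).
def loss(w, x, y):
    pred = x @ w                      # NumPy-style matrix-vector product
    return jnp.mean((pred - y) ** 2)  # mean squared error

# Automatic differentiation: jax.grad returns a new function computing d(loss)/dw.
grad_loss = jax.grad(loss)

# XLA compilation: jax.jit traces the function on first call and compiles it
# for the default backend (CPU, GPU, or TPU, whichever JAX detects).
fast_grad = jax.jit(grad_loss)

x = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
y = jnp.array([1.0, 2.0, 3.0])
w = jnp.zeros(2)

g = fast_grad(w, x, y)
print("gradient:", g)

# Vectorization: jax.vmap maps loss over the leading axis of x and y
# (w is shared across examples, hence None), giving one loss per example.
per_example_loss = jax.vmap(loss, in_axes=(None, 0, 0))
losses = per_example_loss(w, x, y)
print("per-example losses:", losses)

# Lists the devices JAX will place computations on.
print("devices:", jax.devices())
```

These transformations compose, so per-example gradients compiled with XLA are just `jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0, 0)))`; most libraries below build on exactly these primitives.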
This is a curated list of awesome JAX libraries, projects, and other resources. Contributions are welcome!
Be sure to check out our (experimental) interactive web version: https://lockwo.github.io/awesome-jax/.
Why do we need another "awesome-jax" list? The existing ones are inactive; this list is directly based on the no-longer-active Awesome JAX repos https://github.com/n2cholas/awesome-jax/ and https://github.com/mhlr/awesome-jax.
## Contents
- [Libraries](#libraries)
- [Models and Projects](#models-and-projects)
- [Tutorials and Blog Posts](#tutorials-and-blog-posts)
- [Community](#community)

## Libraries

- Neural Network Libraries
    - [Flax](https://github.com/google/flax) - Flax is a neural network library for JAX that is designed for flexibility.
    - [Equinox](https://github.com/patrick-kidger/equinox) - Elegant easy-to-use neural networks + scientific computing in JAX.
- Reinforcement Learning Libraries
    - Algorithms
        - [cleanrl](https://github.com/vwxyzjn/cleanrl) - High-quality single-file implementation of Deep Reinforcement Learning algorithms with research-friendly features (PPO, DQN, C51, DDPG, TD3, SAC, PPG).
        - [rlax](https://github.com/google-deepmind/rlax) - A library built on top of JAX that exposes useful building blocks for implementing reinforcement learning agents.
        - [purejaxrl](https://github.com/luchris429/purejaxrl) - Really Fast End-to-End Jax RL Implementations.
        - [Mava](https://github.com/instadeepai/Mava) - 🦁 A research-friendly codebase for fast experimentation of multi-agent reinforcement learning in JAX.
    - Environments
        - [pgx](https://github.com/sotetsuk/pgx) - Vectorized RL game environments in JAX.
        - [jumanji](https://github.com/instadeepai/jumanji) - 🕹️ A diverse suite of scalable reinforcement learning environments in JAX.
        - [gymnax](https://github.com/RobertTLange/gymnax) - RL Environments in JAX.
        - [brax](https://github.com/google/brax) - Massively parallel rigidbody physics simulation on accelerator hardware.
- Natural Language Processing Libraries
    - [levanter](https://github.com/stanford-crfm/levanter) - Legible, Scalable, Reproducible Foundation Models with Named Tensors and Jax.
    - [maxtext](https://github.com/AI-Hypercomputer/maxtext) - A simple, performant and scalable Jax LLM!
    - [EasyLM](https://github.com/young-geng/EasyLM) - Large language models (LLMs) made easy: EasyLM is a one-stop solution for pre-training, fine-tuning, evaluating, and serving LLMs in JAX/Flax.
- JAX Utilities Libraries
    - [jaxtyping](https://github.com/patrick-kidger/jaxtyping) - Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays.
    - [chex](https://github.com/google-deepmind/chex) - A library of utilities to help write reliable JAX code.
    - [mpi4jax](https://github.com/mpi4jax/mpi4jax) - Zero-copy MPI communication of JAX arrays, for turbo-charged HPC applications in Python ⚡.
    - [jax-tqdm](https://github.com/jeremiecoullon/jax-tqdm) - Add a tqdm progress bar to your JAX scans and loops.
    - [JAX-Toolbox](https://github.com/NVIDIA/JAX-Toolbox) - JAX Toolbox provides a public CI, Docker images for popular JAX libraries, and optimized JAX examples to simplify and enhance your JAX development experience on NVIDIA GPUs.
    - [penzai](https://github.com/google-deepmind/penzai) - A JAX research toolkit for building, editing, and visualizing neural networks.
    - [orbax](https://github.com/google/orbax) - Orbax provides common checkpointing and persistence utilities for JAX users.
- Computer Vision Libraries
    - [Scenic](https://github.com/google-research/scenic) - Scenic: A Jax Library for Computer Vision Research and Beyond.
    - [dm_pix](https://github.com/google-deepmind/dm_pix) - PIX is an image processing library in JAX, for JAX.
- Distributions, Sampling, and Probabilistic Libraries
    - [distreqx](https://github.com/lockwo/distreqx) - Distrax, but in Equinox. Lightweight JAX library of probability distributions and bijectors.
    - [distrax](https://github.com/google-deepmind/distrax) - A lightweight library of probability distributions and bijectors.
    - [flowjax](https://github.com/danielward27/flowjax) - Distributions, bijections and normalizing flows using Equinox and JAX.
    - [blackjax](https://github.com/blackjax-devs/blackjax) - BlackJAX is a Bayesian Inference library designed for ease of use, speed and modularity.
    - [bayex](https://github.com/alonfnt/bayex) - Minimal Implementation of Bayesian Optimization in JAX.
    - [efax](https://github.com/NeilGirdhar/efax) - Exponential families for JAX.
- [GPJax](https://github.com/JaxGaussianProcesses/GPJax) - Gaussian processes in JAX.
- [tinygp](https://github.com/dfm/tinygp) - The tiniest of Gaussian Process libraries.
- [Diffrax](https://github.com/patrick-kidger/diffrax) - Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable.
- [jax-md](https://github.com/jax-md/jax-md) - Differentiable, Hardware Accelerated, Molecular Dynamics.
- [lineax](https://github.com/patrick-kidger/lineax) - Linear solvers in JAX and Equinox.
- [optimistix](https://github.com/patrick-kidger/optimistix) - Nonlinear optimisation (root-finding, least squares, etc.) in JAX+Equinox.
- [sympy2jax](https://github.com/patrick-kidger/sympy2jax) - Turn SymPy expressions into trainable JAX expressions.
- [quax](https://github.com/patrick-kidger/quax) - Multiple dispatch over abstract array types in JAX.
- [interpax](https://github.com/f0uriest/interpax) - Interpolation and function approximation with JAX.
- [quadax](https://github.com/f0uriest/quadax) - Numerical quadrature with JAX.
- [optax](https://github.com/google-deepmind/optax) - Optax is a gradient processing and optimization library for JAX.
- [dynamax](https://github.com/probml/dynamax) - State Space Models library in JAX.
- [dynamiqs](https://github.com/dynamiqs/dynamiqs) - High-performance quantum systems simulation with JAX (GPU-accelerated & differentiable solvers).
- [scico](https://github.com/lanl/scico) - Scientific Computational Imaging COde.
- [exojax](https://github.com/HajimeKawahara/exojax) - Automatic differentiable spectrum modeling of exoplanets/brown dwarfs using JAX, compatible with NumPyro and Optax/JAXopt.
- [PGMax](https://github.com/google-deepmind/PGMax) - Loopy belief propagation for factor graphs on discrete variables in JAX.
- [evosax](https://github.com/RobertTLange/evosax) - Evolution Strategies in JAX 🦎.
- [evojax](https://github.com/google/evojax) - EvoJAX is a scalable, general-purpose, hardware-accelerated neuroevolution toolkit. Built on top of the JAX library, this toolkit enables neuroevolution algorithms to work with neural networks running in parallel across multiple TPU/GPUs.
- [mctx](https://github.com/google-deepmind/mctx) - Monte Carlo tree search in JAX.
- [kfac-jax](https://github.com/google-deepmind/kfac-jax) - Second Order Optimization and Curvature Estimation with K-FAC in JAX.
- [jwave](https://github.com/ucl-bug/jwave) - A JAX-based research framework for differentiable and parallelizable acoustic simulations, on CPU, GPUs and TPUs.
- [jax_cosmo](https://github.com/DifferentiableUniverseInitiative/jax_cosmo) - A differentiable cosmology library in JAX.
- [jaxlie](https://github.com/brentyi/jaxlie) - Rigid transforms + Lie groups in JAX.
- [ott](https://github.com/ott-jax/ott) - Optimal transport tools implemented with the JAX framework, to get differentiable, parallel and jit-able computations.
- [XLB](https://github.com/Autodesk/XLB) - XLB: Accelerated Lattice Boltzmann (XLB) for Physics-based ML.
- [EasyDeL](https://github.com/erfanzar/EasyDeL) - Accelerate and optimize performance with streamlined training and serving options in JAX.
- [QDax](https://github.com/adaptive-intelligent-robotics/QDax) - Accelerated Quality-Diversity.
- [paxml](https://github.com/google/paxml) - Pax is a Jax-based machine learning framework for training large-scale models. Pax allows for advanced and fully configurable experimentation and parallelization, and has demonstrated industry-leading model FLOP utilization rates.
- [econpizza](https://github.com/gboehl/econpizza) - Solve nonlinear heterogeneous agent models.
- [fedjax](https://github.com/google/fedjax) - FedJAX is a JAX-based open source library for Federated Learning simulations that emphasizes ease-of-use in research.
- [neural-tangents](https://github.com/google/neural-tangents) - Fast and Easy Infinite Neural Networks in Python.
- [jax-fem](https://github.com/deepmodeling/jax-fem) - Differentiable Finite Element Method with JAX.
- [veros](https://github.com/team-ocean/veros) - The versatile ocean simulator, in pure Python, powered by JAX.
- [JAXFLUIDS](https://github.com/tumaer/JAXFLUIDS) - Differentiable Fluid Dynamics Package.

### Up and Coming Libraries
- [traceax](https://github.com/mancusolab/traceax) - Stochastic trace estimation using JAX.
- [graphax](https://github.com/jamielohoff/graphax) - Cross-Country Elimination in JAX.
- [cd_dynamax](https://github.com/hd-UQ/cd_dynamax) - Extension of the dynamax repo to continuous-time dynamics with measurements sampled at possibly irregular discrete times. Allows generic inference of dynamical systems parameters from partial noisy observations via auto-differentiable filtering, SGD, and HMC.

### Inactive Libraries
- [Haiku](https://github.com/google-deepmind/dm-haiku) - JAX-based neural network library.
- [jraph](https://github.com/google-deepmind/jraph) - A Graph Neural Network Library in Jax.
- [SymJAX](https://github.com/SymJAX/SymJAX) - Symbolic CPU/GPU/TPU programming.
- [coax](https://github.com/coax-dev/coax) - Modular framework for Reinforcement Learning in Python.
- [eqxvision](https://github.com/paganpasta/eqxvision) - A Python package of computer vision models for the Equinox ecosystem.
- [jaxfit](https://github.com/dipolar-quantum-gases/jaxfit) - GPU/TPU-accelerated nonlinear least-squares curve fitting using JAX.
- [safejax](https://github.com/alvarobartt/safejax) - Serialize JAX, Flax, Haiku, or Objax model params with 🤗 `safetensors`.
- [kernex](https://github.com/ASEM000/kernex) - Stencil computations in JAX.
- [lorax](https://github.com/davisyoshida/lorax) - LoRA for arbitrary JAX models and functions.
- [mcx](https://github.com/rlouf/mcx) - Express & compile probabilistic programs for performant inference on CPU & GPU. Powered by JAX.
- [einshape](https://github.com/google-deepmind/einshape) - DSL-based reshaping library for JAX and other frameworks.
- [jax-flows](https://github.com/ChrisWaites/jax-flows) - Normalizing Flows in JAX.
- [sklearn-jax-kernels](https://github.com/ExpectationMax/sklearn-jax-kernels) - Composable kernels for scikit-learn implemented in JAX.
- [deltapv](https://github.com/romanodev/deltapv) - A photovoltaic simulator with automatic differentiation.
- [cr-sparse](https://github.com/carnotresearch/cr-sparse) - Functional models and algorithms for sparse signal processing.
- [flaxvision](https://github.com/rolandgvc/flaxvision) - A selection of neural network models ported from torchvision for JAX & Flax.
- [imax](https://github.com/4rtemi5/imax) - Image augmentation library for Jax.
- [jax-unirep](https://github.com/ElArkk/jax-unirep) - Reimplementation of the UniRep protein featurization model.
- [parallax](https://github.com/srush/parallax) - Immutable Torch Modules for JAX.
- [jax-resnet](https://github.com/n2cholas/jax-resnet/) - Implementations and checkpoints for ResNet, Wide ResNet, ResNeXt, ResNet-D, and ResNeSt in JAX (Flax).
- [elegy](https://github.com/poets-ai/elegy/) - A High Level API for Deep Learning in JAX.
- [objax](https://github.com/google/objax) - Objax is an open source machine learning framework that accelerates research and learning thanks to a minimalist object-oriented design and a readable code base.
- [jaxrl](https://github.com/ikostrikov/jaxrl) - JAX (Flax) implementation of algorithms for Deep Reinforcement Learning with continuous action spaces.

## Models and Projects
- [whisper-jax](https://github.com/sanchit-gandhi/whisper-jax) - JAX implementation of OpenAI's Whisper model for up to 70x speed-up on TPU.
- [esm2quinox](https://github.com/patrick-kidger/esm2quinox) - An implementation of ESM2 in Equinox+JAX.

## Tutorials and Blog Posts
- [Learning JAX as a PyTorch developer](https://kidger.site/thoughts/torch2jax/)
- [Massively parallel MCMC with JAX](https://rlouf.github.io/post/jax-random-walk-metropolis/)
- [Achieving Over 4000x Speedups and Meta-Evolving Discoveries with PureJaxRL](https://chrislu.page/blog/meta-disco/)
- [How to add a progress bar to JAX scans and loops](https://www.jeremiecoullon.com/2021/01/29/jax_progress_bar/)
- [MCMC in JAX with benchmarks: 3 ways to write a sampler](https://www.jeremiecoullon.com/2020/11/10/mcmcjax3ways/)
- [Deterministic ADVI in JAX](https://martiningram.github.io/deterministic-advi/)
- [Exploring hyperparameter meta-loss landscapes with Jax](http://lukemetz.com/exploring-hyperparameter-meta-loss-landscapes-with-jax/)
- [Evolving Neural Networks in JAX](https://roberttlange.com/posts/2021/02/cma-es-jax/)
- [Meta-Learning in 50 Lines of JAX](https://blog.evjang.com/2019/02/maml-jax.html)
- [Implementing NeRF in JAX](https://wandb.ai/wandb/nerf-jax/reports/Implementing-NeRF-in-JAX--VmlldzoxODA2NDk2?galleryTag=jax)
- [Normalizing Flows in 100 Lines of JAX](https://blog.evjang.com/2019/07/nf-jax.html)
- [JAX vs Julia (vs PyTorch)](https://kidger.site/thoughts/jax-vs-julia/)
- [From PyTorch to JAX: towards neural net frameworks that purify stateful code](https://sjmielke.com/jax-purify.htm)
- [out of distribution detection using focal loss](http://matpalm.com/blog/ood_using_focal_loss/)
- [Differentiable Path Tracing on the GPU/TPU](https://blog.evjang.com/2019/11/jaxpt.html)
- [Getting started with JAX (MLPs, CNNs & RNNs)](https://roberttlange.com/posts/2020/03/blog-post-10/)

### Videos
- [NeurIPS 2020: JAX Ecosystem Meetup](https://www.youtube.com/watch?v=iDxJxIyzSiM)
- [Introduction to JAX](https://www.youtube.com/watch?v=0mVmRHMaOJ4)
- [JAX: Accelerated Machine Learning Research | SciPy 2020 | VanderPlas](https://www.youtube.com/watch?v=z-WSrQDXkuM)
- [Bayesian Programming with JAX + NumPyro - Andy Kitchen](https://www.youtube.com/watch?v=CecuWGpoztw)

## Community
- [JAX LLM Discord](https://discord.gg/CKazXcbbBm)