awesome-jax

Curated list of JAX Resources and Packages
https://github.com/lockwo/awesome-jax

  • Libraries

      • Stoix - đŸ›ī¸A research-friendly codebase for fast experimentation of single-agent reinforcement learning in JAX â€ĸ End-to-End JAX RL. <img src="https://img.shields.io/github/stars/EdanToledo/Stoix?style=social" align="center">
      • Kinetix - Reinforcement learning on general 2D physics environments in JAX. ICLR 2025 Oral. <img src="https://img.shields.io/github/stars/FLAIROx/Kinetix?style=social" align="center">
      • fdtdx - Electromagnetic FDTD Simulations in JAX. <img src="https://img.shields.io/github/stars/ymahlau/fdtdx?style=social" align="center">
      • jaxns - Probabilistic Programming and Nested sampling in JAX. <img src="https://img.shields.io/github/stars/Joshuaalbert/jaxns?style=social" align="center">
      • probdiffeq - Probabilistic solvers for differential equations in JAX. Adaptive ODE solvers with calibration, state-space model factorisations, and custom information operators. Compatible with the broader JAX scientific computing ecosystem. <img src="https://img.shields.io/github/stars/pnkraemer/probdiffeq?style=social" align="center">
      • torch2jax - Wraps PyTorch code in a JIT-compatible way for JAX. Supports automatically defining gradients for reverse-mode AutoDiff. <img src="https://img.shields.io/github/stars/rdyro/torch2jax?style=social" align="center">
      • cola - Compositional Linear Algebra. <img src="https://img.shields.io/github/stars/wilson-labs/cola?style=social" align="center">
      • JaxGCRL - Goal-Conditioned Reinforcement Learning with JAX. <img src="https://img.shields.io/github/stars/MichalBortkiewicz/JaxGCRL?style=social" align="center">
      • Equinox - Elegant easy-to-use neural networks + scientific computing in JAX. <img src="https://img.shields.io/github/stars/patrick-kidger/equinox?style=social" align="center">
      • gymnax - RL Environments in JAX 🌍. <img src="https://img.shields.io/github/stars/RobertTLange/gymnax?style=social" align="center">
      • Flax - Flax is a neural network library for JAX that is designed for flexibility. <img src="https://img.shields.io/github/stars/google/flax?style=social" align="center">
      • cleanrl - High-quality single file implementation of Deep Reinforcement Learning algorithms with research-friendly features (PPO, DQN, C51, DDPG, TD3, SAC, PPG). <img src="https://img.shields.io/github/stars/vwxyzjn/cleanrl?style=social" align="center">
      • rlax - a library built on top of JAX that exposes useful building blocks for implementing reinforcement learning agents. <img src="https://img.shields.io/github/stars/google-deepmind/rlax?style=social" align="center">
      • purejaxrl - Really Fast End-to-End Jax RL Implementations. <img src="https://img.shields.io/github/stars/luchris429/purejaxrl?style=social" align="center">
      • Mava - đŸĻ A research-friendly codebase for fast experimentation of multi-agent reinforcement learning in JAX. <img src="https://img.shields.io/github/stars/instadeepai/Mava?style=social" align="center">
      • pgx - Vectorized RL game environments in JAX. <img src="https://img.shields.io/github/stars/sotetsuk/pgx?style=social" align="center">
      • jumanji - đŸ•šī¸ A diverse suite of scalable reinforcement learning environments in JAX. <img src="https://img.shields.io/github/stars/instadeepai/jumanji?style=social" align="center">
      • brax - Massively parallel rigidbody physics simulation on accelerator hardware. <img src="https://img.shields.io/github/stars/google/brax?style=social" align="center">
      • levanter - Legible, Scalable, Reproducible Foundation Models with Named Tensors and Jax. <img src="https://img.shields.io/github/stars/stanford-crfm/levanter?style=social" align="center">
      • maxtext - A simple, performant and scalable Jax LLM! <img src="https://img.shields.io/github/stars/AI-Hypercomputer/maxtext?style=social" align="center">
      • EasyLM - Large language models (LLMs) made easy, EasyLM is a one stop solution for pre-training, finetuning, evaluating and serving LLMs in JAX/Flax. <img src="https://img.shields.io/github/stars/young-geng/EasyLM?style=social" align="center">
      • jaxtyping - Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays. <img src="https://img.shields.io/github/stars/patrick-kidger/jaxtyping?style=social" align="center">
      • chex - a library of utilities for helping to write reliable JAX code. <img src="https://img.shields.io/github/stars/google-deepmind/chex?style=social" align="center">
      • mpi4jax - Zero-copy MPI communication of JAX arrays, for turbo-charged HPC applications in Python ⚡. <img src="https://img.shields.io/github/stars/mpi4jax/mpi4jax?style=social" align="center">
      • jax-tqdm - Add a tqdm progress bar to your JAX scans and loops. <img src="https://img.shields.io/github/stars/jeremiecoullon/jax-tqdm?style=social" align="center">
      • JAX-Toolbox - JAX Toolbox provides a public CI, Docker images for popular JAX libraries, and optimized JAX examples to simplify and enhance your JAX development experience on NVIDIA GPUs. <img src="https://img.shields.io/github/stars/NVIDIA/JAX-Toolbox?style=social" align="center">
      • penzai - A JAX research toolkit for building, editing, and visualizing neural networks. <img src="https://img.shields.io/github/stars/google-deepmind/penzai?style=social" align="center">
      • orbax - Orbax provides common checkpointing and persistence utilities for JAX users. <img src="https://img.shields.io/github/stars/google/orbax?style=social" align="center">
      • Scenic - A Jax Library for Computer Vision Research and Beyond. <img src="https://img.shields.io/github/stars/google-research/scenic?style=social" align="center">
      • dm_pix - PIX is an image processing library in JAX, for JAX. <img src="https://img.shields.io/github/stars/google-deepmind/dm_pix?style=social" align="center">
      • distreqx - Distrax, but in equinox. Lightweight JAX library of probability distributions and bijectors. <img src="https://img.shields.io/github/stars/lockwo/distreqx?style=social" align="center">
      • distrax - a lightweight library of probability distributions and bijectors. <img src="https://img.shields.io/github/stars/google-deepmind/distrax?style=social" align="center">
      • flowjax - Distributions, bijections and normalizing flows using Equinox and JAX. <img src="https://img.shields.io/github/stars/danielward27/flowjax?style=social" align="center">
      • blackjax - BlackJAX is a Bayesian Inference library designed for ease of use, speed and modularity. <img src="https://img.shields.io/github/stars/blackjax-devs/blackjax?style=social" align="center">
      • bayex - Minimal Implementation of Bayesian Optimization in JAX. <img src="https://img.shields.io/github/stars/alonfnt/bayex?style=social" align="center">
      • efax - Exponential families for JAX. <img src="https://img.shields.io/github/stars/NeilGirdhar/efax?style=social" align="center">
      • GPJax - Gaussian processes in JAX. <img src="https://img.shields.io/github/stars/JaxGaussianProcesses/GPJax?style=social" align="center">
      • tinygp - The tiniest of Gaussian Process libraries. <img src="https://img.shields.io/github/stars/dfm/tinygp?style=social" align="center">
      • Diffrax - Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. <img src="https://img.shields.io/github/stars/patrick-kidger/diffrax?style=social" align="center">
      • jax-md - Differentiable, Hardware Accelerated, Molecular Dynamics. <img src="https://img.shields.io/github/stars/jax-md/jax-md?style=social" align="center">
      • lineax - Linear solvers in JAX and Equinox. <img src="https://img.shields.io/github/stars/patrick-kidger/lineax?style=social" align="center">
      • optimistix - Nonlinear optimisation (root-finding, least squares, etc.) in JAX+Equinox. <img src="https://img.shields.io/github/stars/patrick-kidger/optimistix?style=social" align="center">
      • sympy2jax - Turn SymPy expressions into trainable JAX expressions. <img src="https://img.shields.io/github/stars/patrick-kidger/sympy2jax?style=social" align="center">
      • quax - Multiple dispatch over abstract array types in JAX. <img src="https://img.shields.io/github/stars/patrick-kidger/quax?style=social" align="center">
      • interpax - Interpolation and function approximation with JAX. <img src="https://img.shields.io/github/stars/f0uriest/interpax?style=social" align="center">
      • quadax - Numerical quadrature with JAX. <img src="https://img.shields.io/github/stars/f0uriest/quadax?style=social" align="center">
      • optax - Optax is a gradient processing and optimization library for JAX. <img src="https://img.shields.io/github/stars/google-deepmind/optax?style=social" align="center">
      • dynamax - State Space Models library in JAX. <img src="https://img.shields.io/github/stars/probml/dynamax?style=social" align="center">
      • dynamiqs - High-performance quantum systems simulation with JAX (GPU-accelerated & differentiable solvers). <img src="https://img.shields.io/github/stars/dynamiqs/dynamiqs?style=social" align="center">
      • scico - Scientific Computational Imaging COde. <img src="https://img.shields.io/github/stars/lanl/scico?style=social" align="center">
      • exojax - 🐈 Automatic differentiable spectrum modeling of exoplanets/brown dwarfs using JAX, compatible with NumPyro and Optax/JAXopt. <img src="https://img.shields.io/github/stars/HajimeKawahara/exojax?style=social" align="center">
      • PGMax - Loopy belief propagation for factor graphs on discrete variables in JAX. <img src="https://img.shields.io/github/stars/google-deepmind/PGMax?style=social" align="center">
      • evosax - Evolution Strategies in JAX 🦎. <img src="https://img.shields.io/github/stars/RobertTLange/evosax?style=social" align="center">
      • evojax - EvoJAX is a scalable, general purpose, hardware-accelerated neuroevolution toolkit. Built on top of the JAX library, this toolkit enables neuroevolution algorithms to work with neural networks running in parallel across multiple TPU/GPUs. <img src="https://img.shields.io/github/stars/google/evojax?style=social" align="center">
      • mctx - Monte Carlo tree search in JAX. <img src="https://img.shields.io/github/stars/google-deepmind/mctx?style=social" align="center">
      • kfac-jax - Second Order Optimization and Curvature Estimation with K-FAC in JAX. <img src="https://img.shields.io/github/stars/google-deepmind/kfac-jax?style=social" align="center">
      • jwave - A JAX-based research framework for differentiable and parallelizable acoustic simulations, on CPU, GPUs and TPUs. <img src="https://img.shields.io/github/stars/ucl-bug/jwave?style=social" align="center">
      • jax_cosmo - A differentiable cosmology library in JAX. <img src="https://img.shields.io/github/stars/DifferentiableUniverseInitiative/jax_cosmo?style=social" align="center">
      • jaxlie - Rigid transforms + Lie groups in JAX. <img src="https://img.shields.io/github/stars/brentyi/jaxlie?style=social" align="center">
      • ott - Optimal transport tools implemented with the JAX framework, to get differentiable, parallel and jit-able computations. <img src="https://img.shields.io/github/stars/ott-jax/ott?style=social" align="center">
      • EasyDeL - Accelerate, Optimize performance with streamlined training and serving options with JAX. <img src="https://img.shields.io/github/stars/erfanzar/EasyDeL?style=social" align="center">
      • QDax - Accelerated Quality-Diversity. <img src="https://img.shields.io/github/stars/adaptive-intelligent-robotics/QDax?style=social" align="center">
      • XLB - Accelerated Lattice Boltzmann (XLB) for Physics-based ML. <img src="https://img.shields.io/github/stars/Autodesk/XLB?style=social" align="center">
      • paxml - Pax is a Jax-based machine learning framework for training large scale models. Pax allows for advanced and fully configurable experimentation and parallelization, and has demonstrated industry leading model flop utilization rates. <img src="https://img.shields.io/github/stars/google/paxml?style=social" align="center">
      • econpizza - Solve nonlinear heterogeneous agent models. <img src="https://img.shields.io/github/stars/gboehl/econpizza?style=social" align="center">
      • fedjax - FedJAX is a JAX-based open source library for Federated Learning simulations that emphasizes ease-of-use in research. <img src="https://img.shields.io/github/stars/google/fedjax?style=social" align="center">
      • neural-tangents - Fast and Easy Infinite Neural Networks in Python. <img src="https://img.shields.io/github/stars/google/neural-tangents?style=social" align="center">
      • jax-fem - Differentiable Finite Element Method with JAX. <img src="https://img.shields.io/github/stars/deepmodeling/jax-fem?style=social" align="center">
      • veros - The versatile ocean simulator, in pure Python, powered by JAX. <img src="https://img.shields.io/github/stars/team-ocean/veros?style=social" align="center">
      • JAXFLUIDS - Differentiable Fluid Dynamics Package. <img src="https://img.shields.io/github/stars/tumaer/JAXFLUIDS?style=social" align="center">
      • JaxMARL - Multi-Agent Reinforcement Learning with JAX. <img src="https://img.shields.io/github/stars/FLAIROx/JaxMARL?style=social" align="center">
      • craftax - (Crafter + NetHack) in JAX. ICML 2024 Spotlight. <img src="https://img.shields.io/github/stars/MichaelTMatthews/Craftax?style=social" align="center">
      • navix - Accelerated minigrid environments with JAX. <img src="https://img.shields.io/github/stars/epignatelli/navix?style=social" align="center">
      • klujax - Solve sparse linear systems in JAX using the KLU algorithm. <img src="https://img.shields.io/github/stars/flaport/klujax?style=social" align="center">
      • coreax - A library for coreset algorithms, written in Jax for fast execution and GPU support. <img src="https://img.shields.io/github/stars/gchq/coreax?style=social" align="center">
      • Jaxley - Differentiable neuron simulations with biophysical detail on CPU, GPU, or TPU. <img src="https://img.shields.io/github/stars/jaxleyverse/jaxley?style=social" align="center">
    • Up and Coming Libraries

      • traceax - Stochastic trace estimation using JAX. <img src="https://img.shields.io/github/stars/mancusolab/traceax?style=social" align="center">
      • graphax - Cross-Country Elimination in JAX. <img src="https://img.shields.io/github/stars/jamielohoff/graphax?style=social" align="center">
      • cd_dynamax - Extension of dynamax repo to cases with continuous-time dynamics with measurements sampled at possibly irregular discrete times. Allows generic inference of dynamical systems parameters from partial noisy observations via auto-differentiable filtering, SGD, and HMC. <img src="https://img.shields.io/github/stars/hd-UQ/cd_dynamax?style=social" align="center">
    • Inactive Libraries

      • Haiku - JAX-based neural network library. <img src="https://img.shields.io/github/stars/google-deepmind/dm-haiku?style=social" align="center">
      • jraph - A Graph Neural Network Library in Jax. <img src="https://img.shields.io/github/stars/google-deepmind/jraph?style=social" align="center">
      • SymJAX - symbolic CPU/GPU/TPU programming. <img src="https://img.shields.io/github/stars/SymJAX/SymJAX?style=social" align="center">
      • coax - Modular framework for Reinforcement Learning in python. <img src="https://img.shields.io/github/stars/coax-dev/coax?style=social" align="center">
      • eqxvision - A Python package of computer vision models for the Equinox ecosystem. <img src="https://img.shields.io/github/stars/paganpasta/eqxvision?style=social" align="center">
      • jaxfit - GPU/TPU accelerated nonlinear least-squares curve fitting using JAX. <img src="https://img.shields.io/github/stars/dipolar-quantum-gases/jaxfit?style=social" align="center">
      • safejax - Serialize JAX, Flax, Haiku, or Objax model params with 🤗`safetensors`. <img src="https://img.shields.io/github/stars/alvarobartt/safejax?style=social" align="center">
      • kernex - Stencil computations in JAX. <img src="https://img.shields.io/github/stars/ASEM000/kernex?style=social" align="center">
      • lorax - LoRA for arbitrary JAX models and functions. <img src="https://img.shields.io/github/stars/davisyoshida/lorax?style=social" align="center">
      • mcx - Express & compile probabilistic programs for performant inference on CPU & GPU. Powered by JAX. <img src="https://img.shields.io/github/stars/rlouf/mcx?style=social" align="center">
      • einshape - DSL-based reshaping library for JAX and other frameworks. <img src="https://img.shields.io/github/stars/google-deepmind/einshape?style=social" align="center">
      • jax-flows - Normalizing Flows in JAX 🌊. <img src="https://img.shields.io/github/stars/ChrisWaites/jax-flows?style=social" align="center">
      • sklearn-jax-kernels - Composable kernels for scikit-learn implemented in JAX. <img src="https://img.shields.io/github/stars/ExpectationMax/sklearn-jax-kernels?style=social" align="center">
      • deltapv - A photovoltaic simulator with automatic differentiation. <img src="https://img.shields.io/github/stars/romanodev/deltapv?style=social" align="center">
      • cr-sparse - Functional models and algorithms for sparse signal processing. <img src="https://img.shields.io/github/stars/carnotresearch/cr-sparse?style=social" align="center">
      • flaxvision - A selection of neural network models ported from torchvision for JAX & Flax. <img src="https://img.shields.io/github/stars/rolandgvc/flaxvision?style=social" align="center">
      • imax - Image augmentation library for Jax. <img src="https://img.shields.io/github/stars/4rtemi5/imax?style=social" align="center">
      • jax-unirep - Reimplementation of the UniRep protein featurization model. <img src="https://img.shields.io/github/stars/ElArkk/jax-unirep?style=social" align="center">
      • parallax - Immutable Torch Modules for JAX. <img src="https://img.shields.io/github/stars/srush/parallax?style=social" align="center">
      • elegy - A High Level API for Deep Learning in JAX. <img src="https://img.shields.io/github/stars/poets-ai/elegy?style=social" align="center">
      • objax - Objax is an open source machine learning framework that accelerates research and learning thanks to a minimalist object-oriented design and a readable code base. <img src="https://img.shields.io/github/stars/google/objax?style=social" align="center">
      • jaxrl - JAX (Flax) implementation of algorithms for Deep Reinforcement Learning with continuous action spaces. <img src="https://img.shields.io/github/stars/ikostrikov/jaxrl?style=social" align="center">
      • jax-resnet - Implementations and checkpoints for ResNet, Wide ResNet, ResNeXt, ResNet-D, and ResNeSt in JAX (Flax). <img src="https://img.shields.io/github/stars/n2cholas/jax-resnet?style=social" align="center">
  • Tutorials and Blog Posts

  • Models and Projects

    • Inactive Libraries

      • whisper-jax - JAX implementation of OpenAI's Whisper model for up to 70x speed-up on TPU. <img src="https://img.shields.io/github/stars/sanchit-gandhi/whisper-jax?style=social" align="center">
      • esm2quinox - An implementation of ESM2 in Equinox+JAX. <img src="https://img.shields.io/github/stars/patrick-kidger/esm2quinox?style=social" align="center">