awesome-jax
Curated list of JAX Resources and Packages
https://github.com/lockwo/awesome-jax
Libraries
- Stoix - A research-friendly codebase for fast experimentation of single-agent reinforcement learning in JAX • End-to-End JAX RL. <img src="https://img.shields.io/github/stars/EdanToledo/Stoix?style=social" align="center">
- Kinetix - Reinforcement learning on general 2D physics environments in JAX. ICLR 2025 Oral. <img src="https://img.shields.io/github/stars/FLAIROx/Kinetix?style=social" align="center">
- fdtdx - Electromagnetic FDTD Simulations in JAX. <img src="https://img.shields.io/github/stars/ymahlau/fdtdx?style=social" align="center">
- jaxns - Probabilistic Programming and Nested sampling in JAX. <img src="https://img.shields.io/github/stars/Joshuaalbert/jaxns?style=social" align="center">
- probdiffeq - Probabilistic solvers for differential equations in JAX. Adaptive ODE solvers with calibration, state-space model factorisations, and custom information operators. Compatible with the broader JAX scientific computing ecosystem. <img src="https://img.shields.io/github/stars/pnkraemer/probdiffeq?style=social" align="center">
- torch2jax - Wraps PyTorch code in a JIT-compatible way for JAX. Supports automatically defining gradients for reverse-mode AutoDiff. <img src="https://img.shields.io/github/stars/rdyro/torch2jax?style=social" align="center">
- cola - Compositional Linear Algebra. <img src="https://img.shields.io/github/stars/wilson-labs/cola?style=social" align="center">
- JaxGCRL - Goal-Conditioned Reinforcement Learning with JAX. <img src="https://img.shields.io/github/stars/MichalBortkiewicz/JaxGCRL?style=social" align="center">
- Equinox - Elegant easy-to-use neural networks + scientific computing in JAX. <img src="https://img.shields.io/github/stars/patrick-kidger/equinox?style=social" align="center">
- gymnax - RL Environments in JAX. <img src="https://img.shields.io/github/stars/RobertTLange/gymnax?style=social" align="center">
- Flax - Flax is a neural network library for JAX that is designed for flexibility. <img src="https://img.shields.io/github/stars/google/flax?style=social" align="center">
- cleanrl - High-quality single-file implementations of Deep Reinforcement Learning algorithms with research-friendly features (PPO, DQN, C51, DDPG, TD3, SAC, PPG). <img src="https://img.shields.io/github/stars/vwxyzjn/cleanrl?style=social" align="center">
- rlax - a library built on top of JAX that exposes useful building blocks for implementing reinforcement learning agents. <img src="https://img.shields.io/github/stars/google-deepmind/rlax?style=social" align="center">
- purejaxrl - Really Fast End-to-End Jax RL Implementations. <img src="https://img.shields.io/github/stars/luchris429/purejaxrl?style=social" align="center">
- Mava - A research-friendly codebase for fast experimentation of multi-agent reinforcement learning in JAX. <img src="https://img.shields.io/github/stars/instadeepai/Mava?style=social" align="center">
- pgx - Vectorized RL game environments in JAX. <img src="https://img.shields.io/github/stars/sotetsuk/pgx?style=social" align="center">
- jumanji - A diverse suite of scalable reinforcement learning environments in JAX. <img src="https://img.shields.io/github/stars/instadeepai/jumanji?style=social" align="center">
- brax - Massively parallel rigidbody physics simulation on accelerator hardware. <img src="https://img.shields.io/github/stars/google/brax?style=social" align="center">
- levanter - Legible, Scalable, Reproducible Foundation Models with Named Tensors and Jax. <img src="https://img.shields.io/github/stars/stanford-crfm/levanter?style=social" align="center">
- maxtext - A simple, performant and scalable Jax LLM! <img src="https://img.shields.io/github/stars/AI-Hypercomputer/maxtext?style=social" align="center">
- EasyLM - Large language models (LLMs) made easy: EasyLM is a one-stop solution for pre-training, fine-tuning, evaluating and serving LLMs in JAX/Flax. <img src="https://img.shields.io/github/stars/young-geng/EasyLM?style=social" align="center">
- jaxtyping - Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays. <img src="https://img.shields.io/github/stars/patrick-kidger/jaxtyping?style=social" align="center">
- chex - a library of utilities for helping to write reliable JAX code. <img src="https://img.shields.io/github/stars/google-deepmind/chex?style=social" align="center">
- mpi4jax - Zero-copy MPI communication of JAX arrays, for turbo-charged HPC applications in Python ⚡. <img src="https://img.shields.io/github/stars/mpi4jax/mpi4jax?style=social" align="center">
- jax-tqdm - Add a tqdm progress bar to your JAX scans and loops. <img src="https://img.shields.io/github/stars/jeremiecoullon/jax-tqdm?style=social" align="center">
- JAX-Toolbox - JAX Toolbox provides a public CI, Docker images for popular JAX libraries, and optimized JAX examples to simplify and enhance your JAX development experience on NVIDIA GPUs. <img src="https://img.shields.io/github/stars/NVIDIA/JAX-Toolbox?style=social" align="center">
- penzai - A JAX research toolkit for building, editing, and visualizing neural networks. <img src="https://img.shields.io/github/stars/google-deepmind/penzai?style=social" align="center">
- orbax - Orbax provides common checkpointing and persistence utilities for JAX users. <img src="https://img.shields.io/github/stars/google/orbax?style=social" align="center">
- Scenic - A JAX library for computer vision research and beyond. <img src="https://img.shields.io/github/stars/google-research/scenic?style=social" align="center">
- dm_pix - PIX is an image processing library in JAX, for JAX. <img src="https://img.shields.io/github/stars/google-deepmind/dm_pix?style=social" align="center">
- distreqx - Distrax, but in equinox. Lightweight JAX library of probability distributions and bijectors. <img src="https://img.shields.io/github/stars/lockwo/distreqx?style=social" align="center">
- distrax - a lightweight library of probability distributions and bijectors. <img src="https://img.shields.io/github/stars/google-deepmind/distrax?style=social" align="center">
- flowjax - Distributions, bijections and normalizing flows using Equinox and JAX. <img src="https://img.shields.io/github/stars/danielward27/flowjax?style=social" align="center">
- blackjax - BlackJAX is a Bayesian Inference library designed for ease of use, speed and modularity. <img src="https://img.shields.io/github/stars/blackjax-devs/blackjax?style=social" align="center">
- bayex - Minimal Implementation of Bayesian Optimization in JAX. <img src="https://img.shields.io/github/stars/alonfnt/bayex?style=social" align="center">
- efax - Exponential families for JAX. <img src="https://img.shields.io/github/stars/NeilGirdhar/efax?style=social" align="center">
- GPJax - Gaussian processes in JAX. <img src="https://img.shields.io/github/stars/JaxGaussianProcesses/GPJax?style=social" align="center">
- tinygp - The tiniest of Gaussian Process libraries. <img src="https://img.shields.io/github/stars/dfm/tinygp?style=social" align="center">
- Diffrax - Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. <img src="https://img.shields.io/github/stars/patrick-kidger/diffrax?style=social" align="center">
- jax-md - Differentiable, Hardware Accelerated, Molecular Dynamics. <img src="https://img.shields.io/github/stars/jax-md/jax-md?style=social" align="center">
- lineax - Linear solvers in JAX and Equinox. <img src="https://img.shields.io/github/stars/patrick-kidger/lineax?style=social" align="center">
- optimistix - Nonlinear optimisation (root-finding, least squares, etc.) in JAX+Equinox. <img src="https://img.shields.io/github/stars/patrick-kidger/optimistix?style=social" align="center">
- sympy2jax - Turn SymPy expressions into trainable JAX expressions. <img src="https://img.shields.io/github/stars/patrick-kidger/sympy2jax?style=social" align="center">
- quax - Multiple dispatch over abstract array types in JAX. <img src="https://img.shields.io/github/stars/patrick-kidger/quax?style=social" align="center">
- interpax - Interpolation and function approximation with JAX. <img src="https://img.shields.io/github/stars/f0uriest/interpax?style=social" align="center">
- quadax - Numerical quadrature with JAX. <img src="https://img.shields.io/github/stars/f0uriest/quadax?style=social" align="center">
- optax - Optax is a gradient processing and optimization library for JAX. <img src="https://img.shields.io/github/stars/google-deepmind/optax?style=social" align="center">
- dynamax - State Space Models library in JAX. <img src="https://img.shields.io/github/stars/probml/dynamax?style=social" align="center">
- dynamiqs - High-performance quantum systems simulation with JAX (GPU-accelerated & differentiable solvers). <img src="https://img.shields.io/github/stars/dynamiqs/dynamiqs?style=social" align="center">
- scico - Scientific Computational Imaging COde. <img src="https://img.shields.io/github/stars/lanl/scico?style=social" align="center">
- exojax - Automatically differentiable spectrum modeling of exoplanets/brown dwarfs using JAX, compatible with NumPyro and Optax/JAXopt. <img src="https://img.shields.io/github/stars/HajimeKawahara/exojax?style=social" align="center">
- PGMax - Loopy belief propagation for factor graphs on discrete variables in JAX. <img src="https://img.shields.io/github/stars/google-deepmind/PGMax?style=social" align="center">
- evosax - Evolution Strategies in JAX. <img src="https://img.shields.io/github/stars/RobertTLange/evosax?style=social" align="center">
- evojax - EvoJAX is a scalable, general purpose, hardware-accelerated neuroevolution toolkit. Built on top of the JAX library, this toolkit enables neuroevolution algorithms to work with neural networks running in parallel across multiple TPU/GPUs. <img src="https://img.shields.io/github/stars/google/evojax?style=social" align="center">
- mctx - Monte Carlo tree search in JAX. <img src="https://img.shields.io/github/stars/google-deepmind/mctx?style=social" align="center">
- kfac-jax - Second Order Optimization and Curvature Estimation with K-FAC in JAX. <img src="https://img.shields.io/github/stars/google-deepmind/kfac-jax?style=social" align="center">
- jwave - A JAX-based research framework for differentiable and parallelizable acoustic simulations, on CPU, GPUs and TPUs. <img src="https://img.shields.io/github/stars/ucl-bug/jwave?style=social" align="center">
- jax_cosmo - A differentiable cosmology library in JAX. <img src="https://img.shields.io/github/stars/DifferentiableUniverseInitiative/jax_cosmo?style=social" align="center">
- jaxlie - Rigid transforms + Lie groups in JAX. <img src="https://img.shields.io/github/stars/brentyi/jaxlie?style=social" align="center">
- ott - Optimal transport tools implemented with the JAX framework, to get differentiable, parallel and jit-able computations. <img src="https://img.shields.io/github/stars/ott-jax/ott?style=social" align="center">
- EasyDeL - Accelerate and optimize performance with streamlined training and serving options in JAX. <img src="https://img.shields.io/github/stars/erfanzar/EasyDeL?style=social" align="center">
- QDax - Accelerated Quality-Diversity. <img src="https://img.shields.io/github/stars/adaptive-intelligent-robotics/QDax?style=social" align="center">
- XLB - XLB: Accelerated Lattice Boltzmann (XLB) for Physics-based ML. <img src="https://img.shields.io/github/stars/Autodesk/XLB?style=social" align="center">
- paxml - Pax is a JAX-based machine learning framework for training large-scale models. Pax allows for advanced, fully configurable experimentation and parallelization, and has demonstrated industry-leading model FLOP utilization rates. <img src="https://img.shields.io/github/stars/google/paxml?style=social" align="center">
- econpizza - Solve nonlinear heterogeneous agent models. <img src="https://img.shields.io/github/stars/gboehl/econpizza?style=social" align="center">
- fedjax - FedJAX is a JAX-based open source library for Federated Learning simulations that emphasizes ease-of-use in research. <img src="https://img.shields.io/github/stars/google/fedjax?style=social" align="center">
- neural-tangents - Fast and Easy Infinite Neural Networks in Python. <img src="https://img.shields.io/github/stars/google/neural-tangents?style=social" align="center">
- jax-fem - Differentiable Finite Element Method with JAX. <img src="https://img.shields.io/github/stars/deepmodeling/jax-fem?style=social" align="center">
- veros - The versatile ocean simulator, in pure Python, powered by JAX. <img src="https://img.shields.io/github/stars/team-ocean/veros?style=social" align="center">
- JAXFLUIDS - Differentiable Fluid Dynamics Package. <img src="https://img.shields.io/github/stars/tumaer/JAXFLUIDS?style=social" align="center">
- JaxMARL - Multi-Agent Reinforcement Learning with JAX. <img src="https://img.shields.io/github/stars/FLAIROx/JaxMARL?style=social" align="center">
- craftax - (Crafter + NetHack) in JAX. ICML 2024 Spotlight. <img src="https://img.shields.io/github/stars/MichaelTMatthews/Craftax?style=social" align="center">
- navix - Accelerated minigrid environments with JAX. <img src="https://img.shields.io/github/stars/epignatelli/navix?style=social" align="center">
- klujax - Solve sparse linear systems in JAX using the KLU algorithm. <img src="https://img.shields.io/github/stars/flaport/klujax?style=social" align="center">
- coreax - A library for coreset algorithms, written in Jax for fast execution and GPU support. <img src="https://img.shields.io/github/stars/gchq/coreax?style=social" align="center">
- Jaxley - Differentiable neuron simulations with biophysical detail on CPU, GPU, or TPU. <img src="https://img.shields.io/github/stars/jaxleyverse/jaxley?style=social" align="center">
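Several entries above (for example gymnax, pgx and purejaxrl) describe vectorized or end-to-end JAX environments. A common way to get that behaviour in JAX is to write the environment step as a pure function. `jax.lax.scan` then runs the time loop, `jax.vmap` batches it across environments, and `jax.jit` compiles the result into one XLA program. The sketch below illustrates the pattern with a hypothetical one-dimensional point-mass environment. `env_step`, `rollout` and the reward are invented for illustration and are not the API of any library listed here.

```python
import jax
import jax.numpy as jnp

# Hypothetical toy environment (not gymnax/pgx/brax API): a 1-D point mass
# whose state is (position, velocity) and whose action is an acceleration.
def env_step(state, action, dt=0.1):
    pos, vel = state
    vel = vel + dt * action        # integrate acceleration
    pos = pos + dt * vel           # integrate velocity
    reward = -jnp.abs(pos)         # reward staying close to the origin
    return (pos, vel), reward

def rollout(init_state, actions):
    # lax.scan keeps the whole time loop inside one compiled XLA program.
    final_state, rewards = jax.lax.scan(env_step, init_state, actions)
    return final_state, rewards.sum()

# vmap batches the single-environment rollout over a leading axis;
# jit compiles the batched function once.
batched_rollout = jax.jit(jax.vmap(rollout))

num_envs, horizon = 1024, 50
init_pos = jax.random.uniform(jax.random.PRNGKey(0), (num_envs,), minval=-1.0, maxval=1.0)
init_state = (init_pos, jnp.zeros(num_envs))
actions = jnp.zeros((num_envs, horizon))  # a do-nothing policy, for illustration
final_state, returns = batched_rollout(init_state, actions)
print(returns.shape)  # (1024,)
```

Replacing the zero actions with a policy's outputs, and `env_step` with a real environment, gives roughly the shape of an end-to-end JAX RL loop.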
Up and Coming Libraries
- traceax - Stochastic trace estimation using JAX. <img src="https://img.shields.io/github/stars/mancusolab/traceax?style=social" align="center">
- graphax - Cross-Country Elimination in JAX. <img src="https://img.shields.io/github/stars/jamielohoff/graphax?style=social" align="center">
- cd_dynamax - Extension of dynamax repo to cases with continuous-time dynamics with measurements sampled at possibly irregular discrete times. Allows generic inference of dynamical systems parameters from partial noisy observations via auto-differentiable filtering, SGD, and HMC. <img src="https://img.shields.io/github/stars/hd-UQ/cd_dynamax?style=social" align="center">
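traceax targets stochastic trace estimation. The classic technique in this family is Hutchinson's estimator. It approximates tr(A) by averaging z^T A z over random probe vectors z with E[z z^T] = I, and it needs only matrix-vector products with A. A minimal sketch in plain JAX follows, assuming Rademacher probes. `hutchinson_trace` is a hypothetical helper, not traceax's API.

```python
import jax
import jax.numpy as jnp

def hutchinson_trace(matvec, dim, key, num_samples=1000):
    """Estimate tr(A) using only the matrix-vector product v -> A @ v."""
    # Rademacher probes (+1/-1 with equal probability) satisfy E[z z^T] = I,
    # so E[z^T A z] = tr(A).
    probes = jax.random.rademacher(key, (num_samples, dim)).astype(jnp.float32)
    quad_forms = jax.vmap(lambda z: z @ matvec(z))(probes)
    return quad_forms.mean()

key_a, key_z = jax.random.split(jax.random.PRNGKey(0))
M = jax.random.normal(key_a, (50, 50))
A = M @ M.T  # dense symmetric test matrix
estimate = hutchinson_trace(lambda v: A @ v, dim=50, key=key_z, num_samples=2000)
print(float(estimate), float(jnp.trace(A)))  # close, but generally not equal
```

For a diagonal matrix every Rademacher probe returns the exact trace, since each z_i^2 = 1. The estimator's variance therefore comes entirely from the off-diagonal entries.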
Inactive Libraries
- Haiku - JAX-based neural network library. <img src="https://img.shields.io/github/stars/google-deepmind/dm-haiku?style=social" align="center">
- jraph - A Graph Neural Network Library in Jax. <img src="https://img.shields.io/github/stars/google-deepmind/jraph?style=social" align="center">
- SymJAX - symbolic CPU/GPU/TPU programming. <img src="https://img.shields.io/github/stars/SymJAX/SymJAX?style=social" align="center">
- coax - Modular framework for Reinforcement Learning in python. <img src="https://img.shields.io/github/stars/coax-dev/coax?style=social" align="center">
- eqxvision - A Python package of computer vision models for the Equinox ecosystem. <img src="https://img.shields.io/github/stars/paganpasta/eqxvision?style=social" align="center">
- jaxfit - GPU/TPU accelerated nonlinear least-squares curve fitting using JAX. <img src="https://img.shields.io/github/stars/dipolar-quantum-gases/jaxfit?style=social" align="center">
- safejax - Serialize JAX, Flax, Haiku, or Objax model params with `safetensors`. <img src="https://img.shields.io/github/stars/alvarobartt/safejax?style=social" align="center">
- kernex - Stencil computations in JAX. <img src="https://img.shields.io/github/stars/ASEM000/kernex?style=social" align="center">
- lorax - LoRA for arbitrary JAX models and functions. <img src="https://img.shields.io/github/stars/davisyoshida/lorax?style=social" align="center">
- mcx - Express & compile probabilistic programs for performant inference on CPU & GPU. Powered by JAX. <img src="https://img.shields.io/github/stars/rlouf/mcx?style=social" align="center">
- einshape - DSL-based reshaping library for JAX and other frameworks. <img src="https://img.shields.io/github/stars/google-deepmind/einshape?style=social" align="center">
- jax-flows - Normalizing Flows in JAX. <img src="https://img.shields.io/github/stars/ChrisWaites/jax-flows?style=social" align="center">
- sklearn-jax-kernels - Composable kernels for scikit-learn implemented in JAX. <img src="https://img.shields.io/github/stars/ExpectationMax/sklearn-jax-kernels?style=social" align="center">
- deltapv - A photovoltaic simulator with automatic differentiation. <img src="https://img.shields.io/github/stars/romanodev/deltapv?style=social" align="center">
- cr-sparse - Functional models and algorithms for sparse signal processing. <img src="https://img.shields.io/github/stars/carnotresearch/cr-sparse?style=social" align="center">
- flaxvision - A selection of neural network models ported from torchvision for JAX & Flax. <img src="https://img.shields.io/github/stars/rolandgvc/flaxvision?style=social" align="center">
- imax - Image augmentation library for Jax. <img src="https://img.shields.io/github/stars/4rtemi5/imax?style=social" align="center">
- jax-unirep - Reimplementation of the UniRep protein featurization model. <img src="https://img.shields.io/github/stars/ElArkk/jax-unirep?style=social" align="center">
- parallax - Immutable Torch Modules for JAX. <img src="https://img.shields.io/github/stars/srush/parallax?style=social" align="center">
- elegy - A High Level API for Deep Learning in JAX. <img src="https://img.shields.io/github/stars/poets-ai/elegy?style=social" align="center">
- objax - Objax is an open source machine learning framework that accelerates research and learning thanks to a minimalist object-oriented design and a readable code base. <img src="https://img.shields.io/github/stars/google/objax?style=social" align="center">
- jaxrl - JAX (Flax) implementation of algorithms for Deep Reinforcement Learning with continuous action spaces. <img src="https://img.shields.io/github/stars/ikostrikov/jaxrl?style=social" align="center">
- jax-resnet - Implementations and checkpoints for ResNet, Wide ResNet, ResNeXt, ResNet-D, and ResNeSt in JAX (Flax). <img src="https://img.shields.io/github/stars/n2cholas/jax-resnet?style=social" align="center">
-
-
Tutorials and Blog Posts
-
Inactive Libraries
- Learning JAX as a PyTorch developer
- Massively parallel MCMC with JAX
- Achieving Over 4000x Speedups and Meta-Evolving Discoveries with PureJaxRL
- How to add a progress bar to JAX scans and loops
- MCMC in JAX with benchmarks: 3 ways to write a sampler
- Deterministic ADVI in JAX
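Several of the posts above, such as "Massively parallel MCMC with JAX", rely on running many Markov chains at once. The sketch below shows that idea in plain JAX and is not the method of any linked post. It assumes a toy 2-D Gaussian target, a random-walk Metropolis sampler with step size 0.5, and illustrative helper names (`log_prob`, `rw_metropolis_step`, `run_chain`, `run_chains`). `jax.lax.scan` advances one chain, `jax.vmap` batches it across independent chains, and `jax.jit` compiles the batched sampler.

```python
import functools

import jax
import jax.numpy as jnp


def log_prob(x):
    # Illustrative target (assumption): unnormalised log-density of a standard 2-D Gaussian.
    return -0.5 * jnp.sum(x ** 2)


def rw_metropolis_step(x, key, step_size=0.5):
    """One random-walk Metropolis step for a single chain."""
    key_prop, key_accept = jax.random.split(key)
    proposal = x + step_size * jax.random.normal(key_prop, x.shape)
    # The Gaussian proposal is symmetric, so the acceptance ratio is just the density ratio.
    log_ratio = log_prob(proposal) - log_prob(x)
    accept = jnp.log(jax.random.uniform(key_accept)) < log_ratio
    return jnp.where(accept, proposal, x)


def run_chain(key, x0, num_steps):
    """Run a single chain with lax.scan and return every state it visits."""
    keys = jax.random.split(key, num_steps)

    def body(x, k):
        x_new = rw_metropolis_step(x, k)
        return x_new, x_new  # (carry, per-step output)

    _, samples = jax.lax.scan(body, x0, keys)
    return samples  # shape (num_steps, dim)


@functools.partial(jax.jit, static_argnums=2)
def run_chains(key, x0s, num_steps):
    """Run len(x0s) independent chains in parallel by vmapping run_chain."""
    keys = jax.random.split(key, x0s.shape[0])
    return jax.vmap(lambda k, x0: run_chain(k, x0, num_steps))(keys, x0s)


# Usage: 1,000 chains of 2,000 steps each, all started at the origin.
key = jax.random.PRNGKey(0)
x0s = jnp.zeros((1000, 2))
samples = run_chains(key, x0s, 2000)
print(samples.shape)  # (1000, 2000, 2)
```

Because `run_chain` is written for a single state, the same code serves one chain or thousands. `vmap` adds the chain dimension, so no manual batching logic is needed inside the sampler.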
-
-
Models and Projects
-
Inactive Libraries
- whisper-jax - JAX implementation of OpenAI's Whisper model for up to 70x speed-up on TPU. <img src="https://img.shields.io/github/stars/sanchit-gandhi/whisper-jax?style=social" align="center">
- esm2quinox - An implementation of ESM2 in Equinox+JAX. <img src="https://img.shields.io/github/stars/patrick-kidger/esm2quinox?style=social" align="center">
-