Ecosyste.ms: Awesome

An open API service indexing awesome lists of open source software.


awesome-jax

A curated list of resources for JAX (https://github.com/google/jax).
Source repository: https://github.com/n2cholas/awesome-jax

Last synced: 2 days ago

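Everything in this list builds on JAX's core composable transformations: `jit` for compilation, `grad` for differentiation, and `vmap` for vectorization. As a minimal, self-contained orientation (plain JAX only; the names and shapes are chosen purely for illustration):

```python
import jax
import jax.numpy as jnp

def loss(w, x, y):
    # Squared error of a linear model.
    return jnp.mean((x @ w - y) ** 2)

grad_fn = jax.jit(jax.grad(loss))                    # compiled gradient w.r.t. w
per_example = jax.vmap(loss, in_axes=(None, 0, 0))   # loss for each (x, y) pair

w = jnp.zeros(3)
x = jnp.ones((8, 3))
y = jnp.ones(8)
print(grad_fn(w, x, y))      # shape (3,)
print(per_example(w, x, y))  # shape (8,)
```

The same pattern - write NumPy-style functions, then transform them - carries through the libraries below.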
  • Libraries

      • Jraph - Lightweight graph neural network library. <img src="https://img.shields.io/github/stars/deepmind/jraph?style=social" align="center">
      • Scenic - A Jax Library for Computer Vision Research and Beyond. <img src="https://img.shields.io/github/stars/google-research/scenic?style=social" align="center">
      • Optax - Gradient processing and optimization library. <img src="https://img.shields.io/github/stars/deepmind/optax?style=social" align="center">
      • JAX, M.D. - Accelerated, differentiable molecular dynamics. <img src="https://img.shields.io/github/stars/google/jax-md?style=social" align="center">
      • Distrax - Reimplementation of TensorFlow Probability, containing probability distributions and bijectors. <img src="https://img.shields.io/github/stars/deepmind/distrax?style=social" align="center">
      • Flax - Neural network library for JAX, centered on flexibility and clarity (see the training-step sketch after this list). <img src="https://img.shields.io/github/stars/google/flax?style=social" align="center">
      • Objax - Has an object-oriented design similar to PyTorch. <img src="https://img.shields.io/github/stars/google/objax?style=social" align="center">
      • Elegy - A High Level API for Deep Learning in JAX. Supports Flax, Haiku, and Optax. <img src="https://img.shields.io/github/stars/poets-ai/elegy?style=social" align="center">
      • Trax - "Batteries included" deep learning library focused on providing solutions for common workloads. <img src="https://img.shields.io/github/stars/google/trax?style=social" align="center">
      • Neural Tangents - High-level API for specifying neural networks of both finite and _infinite_ width. <img src="https://img.shields.io/github/stars/google/neural-tangents?style=social" align="center">
      • HuggingFace - Ecosystem of pretrained Transformers for a wide range of natural language tasks (Flax). <img src="https://img.shields.io/github/stars/huggingface/transformers?style=social" align="center">
      • Equinox - Callable PyTrees and filtered JIT/grad transformations => neural networks in JAX. <img src="https://img.shields.io/github/stars/patrick-kidger/equinox?style=social" align="center">
      • Levanter - Legible, Scalable, Reproducible Foundation Models with Named Tensors and JAX. <img src="https://img.shields.io/github/stars/stanford-crfm/levanter?style=social" align="center">
      • EasyLM - LLMs made easy: Pre-training, finetuning, evaluating and serving LLMs in JAX/Flax. <img src="https://img.shields.io/github/stars/young-geng/EasyLM?style=social" align="center">
      • NumPyro - Probabilistic programming based on the Pyro library. <img src="https://img.shields.io/github/stars/pyro-ppl/numpyro?style=social" align="center">
      • Chex - Utilities to write and test reliable JAX code. <img src="https://img.shields.io/github/stars/deepmind/chex?style=social" align="center">
      • Coax - Turn RL papers into code, the easy way. <img src="https://img.shields.io/github/stars/coax-dev/coax?style=social" align="center">
      • cvxpylayers - Construct differentiable convex optimization layers. <img src="https://img.shields.io/github/stars/cvxgrp/cvxpylayers?style=social" align="center">
      • TensorLy - Tensor learning made simple. <img src="https://img.shields.io/github/stars/tensorly/tensorly?style=social" align="center">
      • NetKet - Machine Learning toolbox for Quantum Physics. <img src="https://img.shields.io/github/stars/netket/netket?style=social" align="center">
      • Fortuna - AWS library for Uncertainty Quantification in Deep Learning. <img src="https://img.shields.io/github/stars/awslabs/fortuna?style=social" align="center">
      • BlackJAX - Library of samplers for JAX. <img src="https://img.shields.io/github/stars/blackjax-devs/blackjax?style=social" align="center">
      • RLax - Library for implementing reinforcement learning agents. <img src="https://img.shields.io/github/stars/deepmind/rlax?style=social" align="center">
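Several of the libraries above are designed to compose: for instance, a Flax module can be trained with an Optax optimizer in a handful of lines. The sketch below is illustrative only - the layer sizes, learning rate, and data shapes are invented - but it follows the documented `flax.linen` and `optax` APIs:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.relu(nn.Dense(32)(x))
        return nn.Dense(1)(x)

model = MLP()
params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))  # initialise with a dummy batch
tx = optax.adam(1e-3)
opt_state = tx.init(params)

@jax.jit
def train_step(params, opt_state, x, y):
    def loss_fn(p):
        return jnp.mean((model.apply(p, x) - y) ** 2)
    grads = jax.grad(loss_fn)(params)
    updates, opt_state = tx.update(grads, opt_state, params)
    return optax.apply_updates(params, updates), opt_state
```

Haiku, Equinox, and Objax fill the same role as Flax here, each with a different take on how parameters and model state are represented.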
    • New Libraries

      • Oryx - Probabilistic programming language based on program transformations.
      • flaxmodels - Pretrained models for Jax/Flax. <img src="https://img.shields.io/github/stars/matthias-wright/flaxmodels?style=social" align="center">
      • evosax - JAX-Based Evolution Strategies. <img src="https://img.shields.io/github/stars/RobertTLange/evosax?style=social" align="center">
      • ALX - Open-source library for distributed matrix factorization using Alternating Least Squares, more info in [_ALX: Large Scale Matrix Factorization on TPUs_](https://arxiv.org/abs/2112.02194).
      • FedJAX - Federated learning in JAX, built on Optax and Haiku. <img src="https://img.shields.io/github/stars/google/fedjax?style=social" align="center">
      • Equivariant MLP - Construct equivariant neural network layers. <img src="https://img.shields.io/github/stars/mfinzi/equivariant-MLP?style=social" align="center">
      • jax-resnet - Implementations and checkpoints for ResNet variants in Flax. <img src="https://img.shields.io/github/stars/n2cholas/jax-resnet?style=social" align="center">
      • Parallax - Immutable Torch Modules for JAX. <img src="https://img.shields.io/github/stars/srush/parallax?style=social" align="center">
      • jax-unirep - Library implementing the [UniRep model](https://www.nature.com/articles/s41592-019-0598-1) for protein machine learning applications. <img src="https://img.shields.io/github/stars/ElArkk/jax-unirep?style=social" align="center">
      • jax-flows - Normalizing flows in JAX. <img src="https://img.shields.io/github/stars/ChrisWaites/jax-flows?style=social" align="center">
      • sklearn-jax-kernels - `scikit-learn` kernel matrices using JAX. <img src="https://img.shields.io/github/stars/ExpectationMax/sklearn-jax-kernels?style=social" align="center">
      • jax-cosmo - Differentiable cosmology library. <img src="https://img.shields.io/github/stars/DifferentiableUniverseInitiative/jax_cosmo?style=social" align="center">
      • efax - Exponential Families in JAX. <img src="https://img.shields.io/github/stars/NeilGirdhar/efax?style=social" align="center">
      • imax - Image augmentations and transformations. <img src="https://img.shields.io/github/stars/4rtemi5/imax?style=social" align="center">
      • FlaxVision - Flax version of TorchVision. <img src="https://img.shields.io/github/stars/rolandgvc/flaxvision?style=social" align="center">
      • Optimal Transport Tools - Toolbox that bundles utilities to solve optimal transport problems.
      • delta PV - A photovoltaic simulator with automatic differentiation. <img src="https://img.shields.io/github/stars/romanodev/deltapv?style=social" align="center">
      • jaxlie - Lie theory library for rigid body transformations and optimization. <img src="https://img.shields.io/github/stars/brentyi/jaxlie?style=social" align="center">
      • BRAX - Differentiable physics engine to simulate environments along with learning algorithms to train agents for these environments. <img src="https://img.shields.io/github/stars/google/brax?style=social" align="center">
      • CR.Sparse - XLA accelerated algorithms for sparse representations and compressive sensing. <img src="https://img.shields.io/github/stars/carnotresearch/cr-sparse?style=social" align="center">
      • exojax - Automatically differentiable spectrum modeling of exoplanets and brown dwarfs, compatible with JAX. <img src="https://img.shields.io/github/stars/HajimeKawahara/exojax?style=social" align="center">
      • JAXopt - Hardware accelerated (GPU/TPU), batchable and differentiable optimizers in JAX. <img src="https://img.shields.io/github/stars/google/jaxopt?style=social" align="center">
      • bayex - Bayesian Optimization powered by JAX. <img src="https://img.shields.io/github/stars/alonfnt/bayex?style=social" align="center">
      • JaxDF - Framework for differentiable simulators with arbitrary discretizations. <img src="https://img.shields.io/github/stars/ucl-bug/jaxdf?style=social" align="center">
      • tree-math - Convert functions that operate on arrays into functions that operate on PyTrees. <img src="https://img.shields.io/github/stars/google/tree-math?style=social" align="center">
      • jax-models - Implementations of research papers originally released without code, or with code written in frameworks other than JAX. <img src="https://img.shields.io/github/stars/DarshanDeshpande/jax-models?style=social" align="center">
      • PGMax - A framework for building discrete Probabilistic Graphical Models (PGMs) and running inference on them via JAX. <img src="https://img.shields.io/github/stars/vicariousinc/pgmax?style=social" align="center">
      • EvoJAX - Hardware-Accelerated Neuroevolution. <img src="https://img.shields.io/github/stars/google/evojax?style=social" align="center">
      • SymJAX - Symbolic CPU/GPU/TPU programming. <img src="https://img.shields.io/github/stars/SymJAX/SymJAX?style=social" align="center">
      • mcx - Express & compile probabilistic programs for performant inference. <img src="https://img.shields.io/github/stars/rlouf/mcx?style=social" align="center">
      • Einshape - DSL-based reshaping library for JAX and other frameworks. <img src="https://img.shields.io/github/stars/deepmind/einshape?style=social" align="center">
      • Diffrax - Numerical differential equation solvers in JAX (see the sketch after this list). <img src="https://img.shields.io/github/stars/patrick-kidger/diffrax?style=social" align="center">
      • tinygp - The _tiniest_ of Gaussian process libraries in JAX. <img src="https://img.shields.io/github/stars/dfm/tinygp?style=social" align="center">
      • gymnax - Reinforcement Learning Environments with the well-known gym API. <img src="https://img.shields.io/github/stars/RobertTLange/gymnax?style=social" align="center">
      • Mctx - Monte Carlo tree search algorithms in native JAX. <img src="https://img.shields.io/github/stars/deepmind/mctx?style=social" align="center">
      • TF2JAX - Convert functions/graphs to JAX functions. <img src="https://img.shields.io/github/stars/deepmind/tf2jax?style=social" align="center">
      • jwave - A library for differentiable acoustic simulations. <img src="https://img.shields.io/github/stars/ucl-bug/jwave?style=social" align="center">
      • GPJax - Gaussian processes in JAX.
      • Jumanji - A Suite of Industry-Driven Hardware-Accelerated RL Environments written in JAX. <img src="https://img.shields.io/github/stars/instadeepai/jumanji?style=social" align="center">
      • Eqxvision - Equinox version of Torchvision. <img src="https://img.shields.io/github/stars/paganpasta/eqxvision?style=social" align="center">
      • JAXFit - Accelerated curve fitting library for nonlinear least-squares problems (see [arXiv paper](https://arxiv.org/abs/2208.12187)). <img src="https://img.shields.io/github/stars/dipolar-quantum-gases/jaxfit?style=social" align="center">
      • econpizza - Solve macroeconomic models with heterogeneous agents using JAX. <img src="https://img.shields.io/github/stars/gboehl/econpizza?style=social" align="center">
      • SPU - A domain-specific compiler and runtime suite to run JAX code with MPC (Secure Multi-Party Computation). <img src="https://img.shields.io/github/stars/secretflow/spu?style=social" align="center">
      • jax-tqdm - Add a tqdm progress bar to JAX scans and loops. <img src="https://img.shields.io/github/stars/jeremiecoullon/jax-tqdm?style=social" align="center">
      • safejax - Serialize JAX, Flax, Haiku, or Objax model params with 🤗`safetensors`. <img src="https://img.shields.io/github/stars/alvarobartt/safejax?style=social" align="center">
      • Kernex - Differentiable stencil decorators in JAX. <img src="https://img.shields.io/github/stars/ASEM000/kernex?style=social" align="center">
      • MaxText - A simple, performant and scalable Jax LLM written in pure Python/Jax and targeting Google Cloud TPUs. <img src="https://img.shields.io/github/stars/google/maxtext?style=social" align="center">
      • Pax - A Jax-based machine learning framework for training large scale models. <img src="https://img.shields.io/github/stars/google/paxml?style=social" align="center">
      • Praxis - The layer library for Pax with a goal to be usable by other JAX-based ML projects. <img src="https://img.shields.io/github/stars/google/praxis?style=social" align="center">
      • purejaxrl - Vectorisable, end-to-end RL algorithms in JAX. <img src="https://img.shields.io/github/stars/luchris429/purejaxrl?style=social" align="center">
      • Lorax - Automatically apply LoRA to JAX models (Flax, Haiku, etc.).
      • SCICO - Scientific computational imaging in JAX. <img src="https://img.shields.io/github/stars/lanl/scico?style=social" align="center">
      • Spyx - Spiking Neural Networks in JAX for machine learning on neuromorphic hardware. <img src="https://img.shields.io/github/stars/kmheckel/spyx?style=social" align="center">
      • BrainPy - Brain Dynamics Programming in Python. <img src="https://img.shields.io/github/stars/brainpy/BrainPy?style=social" align="center">
      • OTT-JAX - Optimal transport tools in JAX. <img src="https://img.shields.io/github/stars/ott-jax/ott?style=social" align="center">
      • QDax - Quality Diversity optimization in Jax. <img src="https://img.shields.io/github/stars/adaptive-intelligent-robotics/QDax?style=social" align="center">
      • JAX Toolbox - Nightly CI and optimized examples for JAX on NVIDIA GPUs using libraries such as T5x, Paxml, and Transformer Engine. <img src="https://img.shields.io/github/stars/NVIDIA/JAX-Toolbox?style=social" align="center">
      • dynamiqs - High-performance and differentiable simulations of quantum systems with JAX. <img src="https://img.shields.io/github/stars/dynamiqs/dynamiqs?style=social" align="center">
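As a concrete taste of the entries above, solving a simple ODE with Diffrax looks roughly like the following. The call follows Diffrax's quickstart, but treat it as a sketch against a recent Diffrax release rather than a pinned-down API:

```python
import diffrax
import jax.numpy as jnp

# dy/dt = -y, integrated from t = 0 to t = 3 with an explicit Runge-Kutta solver.
def vector_field(t, y, args):
    return -y

term = diffrax.ODETerm(vector_field)
solver = diffrax.Tsit5()
solution = diffrax.diffeqsolve(term, solver, t0=0.0, t1=3.0, dt0=0.1, y0=jnp.array(1.0))
print(solution.ts, solution.ys)  # times and states saved by the solver
```

Because the solve is an ordinary JAX computation, it can itself be wrapped in `jax.jit` or differentiated with `jax.grad`, which is what makes differentiable solvers useful inside larger models.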
  • Models and Projects

    • JAX

      • Amortized Bayesian Optimization - Code related to [_Amortized Bayesian Optimization over Discrete Spaces_](http://www.auai.org/uai2020/proceedings/329_main_paper.pdf).
      • Accurate Quantized Training - Tools and libraries for running and analyzing neural network quantization experiments in JAX and Flax.
      • BNN-HMC - Implementation for the paper [_What Are Bayesian Neural Network Posteriors Really Like?_](https://arxiv.org/abs/2104.14421).
      • JAX-DFT - One-dimensional density functional theory (DFT) in JAX, with implementation of [_Kohn-Sham equations as regularizer: building prior knowledge into machine-learned physics_](https://journals.aps.org/prl/abstract/10.1103/PhysRevLett.126.036401).
      • Robust Loss - Reference code for the paper [_A General and Adaptive Robust Loss Function_](https://arxiv.org/abs/1701.03077).
      • Symbolic Functionals - Demonstration from [_Evolving symbolic density functionals_](https://arxiv.org/abs/2203.02540).
      • TriMap - Official JAX implementation of [_TriMap: Large-scale Dimensionality Reduction Using Triplets_](https://arxiv.org/abs/1910.00204).
    • Flax

      • Performer - Flax implementation of the Performer (linear transformer via FAVOR+) architecture.
      • JaxNeRF - Implementation of [_NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis_](http://www.matthewtancik.com/nerf) with multi-device GPU/TPU support.
      • RegNeRF - Official implementation of [_RegNeRF: Regularizing Neural Radiance Fields for View Synthesis from Sparse Inputs_](https://m-niemeyer.github.io/regnerf/).
      • gMLP - Implementation of [_Pay Attention to MLPs_](https://arxiv.org/abs/2105.08050).
      • MLP Mixer - Minimal implementation of [_MLP-Mixer: An all-MLP Architecture for Vision_](https://arxiv.org/abs/2105.01601).
      • Distributed Shampoo - Implementation of [_Second Order Optimization Made Practical_](https://arxiv.org/abs/2002.09018).
      • FNet - Official implementation of [_FNet: Mixing Tokens with Fourier Transforms_](https://arxiv.org/abs/2105.03824).
      • GFSA - Official implementation of [_Learning Graph Structure With A Finite-State Automaton Layer_](https://arxiv.org/abs/2007.04929).
      • IPA-GNN - Official implementation of [_Learning to Execute Programs with Instruction Pointer Attention Graph Neural Networks_](https://arxiv.org/abs/2010.12621).
      • Flax Models - Collection of models and methods implemented in Flax.
      • Protein LM - Implements BERT and autoregressive models for proteins, as described in [_Biological Structure and Function Emerge from Scaling Unsupervised Learning to 250 Million Protein Sequences_](https://www.biorxiv.org/content/10.1101/622803v1.full) and [_ProGen: Language Modeling for Protein Generation_](https://www.biorxiv.org/content/10.1101/2020.03.07.982272v2).
      • Slot Attention - Reference implementation for [_Differentiable Patch Selection for Image Recognition_](https://arxiv.org/abs/2104.03059).
      • ARDM - Official implementation of [_Autoregressive Diffusion Models_](https://arxiv.org/abs/2110.02037).
      • D3PM - Official implementation of [_Structured Denoising Diffusion Models in Discrete State-Spaces_](https://arxiv.org/abs/2107.03006).
      • Gumbel-max Causal Mechanisms - Code for [_Learning Generalized Gumbel-max Causal Mechanisms_](https://arxiv.org/abs/2111.06888), with extra code in [GuyLor/gumbel_max_causal_gadgets_part2](https://github.com/GuyLor/gumbel_max_causal_gadgets_part2).
      • Latent Programmer - Code for the ICML 2021 paper [_Latent Programmer: Discrete Latent Codes for Program Synthesis_](https://arxiv.org/abs/2012.00377).
      • SNeRG - Official implementation of [_Baking Neural Radiance Fields for Real-Time View Synthesis_](https://phog.github.io/snerg).
      • Spin-weighted Spherical CNNs - Adaptation of [_Spin-Weighted Spherical CNNs_](https://arxiv.org/abs/2006.10731).
      • VDVAE - Adaptation of [_Very Deep VAEs Generalize Autoregressive Models and Can Outperform Them on Images_](https://arxiv.org/abs/2011.10650), original code at [openai/vdvae](https://github.com/openai/vdvae).
      • MUSIQ - Checkpoints and model inference code for the ICCV 2021 paper [_MUSIQ: Multi-scale Image Quality Transformer_](https://arxiv.org/abs/2108.05997).
      • AQuaDem - Official implementation of [_Continuous Control with Action Quantization from Demonstrations_](https://arxiv.org/abs/2110.10149).
      • Combiner - Official implementation of [_Combiner: Full Attention Transformer with Sparse Computation Cost_](https://arxiv.org/abs/2107.05768).
      • Dreamfields - Official implementation of [_Zero-Shot Text-Guided Object Generation with Dream Fields_](https://ajayj.com/dreamfields).
      • GIFT - Official implementation of [_Gradual Domain Adaptation in the Wild: When Intermediate Distributions are Absent_](https://arxiv.org/abs/2106.06080).
      • Light Field Neural Rendering - Official implementation of [_Light Field Neural Rendering_](https://arxiv.org/abs/2112.09687).
      • Sharpened Cosine Similarity in JAX by Raphael Pisoni - A JAX/Flax implementation of the Sharpened Cosine Similarity layer.
    • Haiku

      • AlphaFold - Implementation of the inference pipeline of AlphaFold v2.0, presented in [_Highly accurate protein structure prediction with AlphaFold_](https://www.nature.com/articles/s41586-021-03819-2).
      • Adversarial Robustness - Reference code for [_Uncovering the Limits of Adversarial Training against Norm-Bounded Adversarial Examples_](https://arxiv.org/abs/2010.03593) and [_Fixing Data Augmentation to Improve Adversarial Robustness_](https://arxiv.org/abs/2103.01946).
      • Bootstrap Your Own Latent - Implementation for the paper [_Bootstrap your own latent: A new approach to self-supervised Learning_](https://arxiv.org/abs/2006.07733).
      • Gated Linear Networks - GLNs are a family of backpropagation-free neural networks.
      • Glassy Dynamics - Open source implementation of the paper [_Unveiling the predictive power of static structure in glassy systems_](https://www.nature.com/articles/s41567-020-0842-8).
      • MMV - Code for the models in [_Self-Supervised MultiModal Versatile Networks_](https://arxiv.org/abs/2006.16228).
      • Normalizer-Free Networks - Official Haiku implementation of [_NFNets_](https://arxiv.org/abs/2102.06171).
      • OGB-LSC - This repository contains DeepMind's entry to the [PCQM4M-LSC](https://ogb.stanford.edu/kddcup2021/pcqm4m/) (quantum chemistry) and [MAG240M-LSC](https://ogb.stanford.edu/kddcup2021/mag240m/) (academic graph) tracks of the OGB Large-Scale Challenge (OGB-LSC).
      • Persistent Evolution Strategies - Code used for the paper [_Unbiased Gradient Estimation in Unrolled Computation Graphs with Persistent Evolution Strategies_](http://proceedings.mlr.press/v139/vicol21a.html).
      • WikiGraphs - Baseline code to reproduce results in [_WikiGraphs: A Wikipedia Text - Knowledge Graph Paired Dataset_](https://aclanthology.org/2021.textgraphs-1.7).
    • Trax

      • Reformer - Implementation of the Reformer (efficient transformer) architecture.
  • Videos

  • Papers

  • Tutorials and Blog Posts

  • Books

      • Jax in Action - A hands-on guide to using JAX for deep learning and other mathematically-intensive applications.
  • Community