
# **JAX Toolbox**

[![License Apache 2.0](https://badgen.net/badge/license/apache2.0/blue)](https://github.com/NVIDIA/JAX-Toolbox/blob/main/LICENSE.md)
[![Build](https://badgen.net/badge/build/check-status/blue)](#build-pipeline-status)

JAX Toolbox provides a public CI, Docker images for popular JAX libraries, and optimized JAX examples to simplify and enhance your JAX development experience on NVIDIA GPUs. It supports JAX libraries such as [MaxText](https://github.com/google/maxtext), [Paxml](https://github.com/google/paxml), and [Pallas](https://jax.readthedocs.io/en/latest/pallas/quickstart.html).

## Frameworks and Supported Models
We support and test the following JAX frameworks and model architectures. More details about each model and available containers can be found in their respective READMEs.

| Framework | Models | Use cases | Container |
| :--- | :---: | :---: | :---: |
| [maxtext](./rosetta/rosetta/projects/maxtext)| GPT, LLaMA, Gemma, Mistral, Mixtral | pretraining | `ghcr.io/nvidia/jax:maxtext` |
| [paxml](./rosetta/rosetta/projects/pax) | GPT, LLaMA, MoE | pretraining, fine-tuning, LoRA | `ghcr.io/nvidia/jax:pax` |
| [t5x](./rosetta/rosetta/projects/t5x) | T5, ViT | pre-training, fine-tuning | `ghcr.io/nvidia/jax:t5x` |
| [t5x](./rosetta/rosetta/projects/imagen) | Imagen | pre-training | `ghcr.io/nvidia/t5x:imagen-2023-10-02.v3` |
| [big vision](./rosetta/rosetta/projects/paligemma) | PaliGemma | fine-tuning, evaluation | `ghcr.io/nvidia/jax:gemma` |
| levanter | GPT, LLaMA, MPT, Backpacks | pretraining, fine-tuning | `ghcr.io/nvidia/jax:levanter` |

## Build Pipeline Status

| Container | Test |
| --------- | ---- |
| `ghcr.io/nvidia/jax:base` | [no tests] |
| `ghcr.io/nvidia/jax:jax` | |
| `ghcr.io/nvidia/jax:levanter` | |
| `ghcr.io/nvidia/jax:equinox` | [tests disabled] |
| `ghcr.io/nvidia/jax:triton` | |
| `ghcr.io/nvidia/jax:upstream-t5x` | |
| `ghcr.io/nvidia/jax:t5x` | |
| `ghcr.io/nvidia/jax:upstream-pax` | |
| `ghcr.io/nvidia/jax:pax` | |
| `ghcr.io/nvidia/jax:maxtext` | |
| `ghcr.io/nvidia/jax:gemma` | |

In all cases, `ghcr.io/nvidia/jax:XXX` points to the latest nightly build of the container for `XXX`. For a stable reference, use `ghcr.io/nvidia/jax:XXX-YYYY-MM-DD`.
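
The tag scheme above can be sketched as a small shell helper. This is only an illustration: `jax_image_ref` is a hypothetical function name, and the date `2024-11-06` is used on the assumption that a nightly image was published for that day.

```bash
# Build an image reference for a JAX Toolbox container.
# With one argument the nightly tag is returned; with a date, the pinned tag.
jax_image_ref() {
  local name="$1"             # container name, e.g. jax, maxtext, pax, t5x
  local build_date="${2:-}"   # optional build date in YYYY-MM-DD form
  if [ -n "${build_date}" ]; then
    printf 'ghcr.io/nvidia/jax:%s-%s\n' "${name}" "${build_date}"
  else
    printf 'ghcr.io/nvidia/jax:%s\n' "${name}"
  fi
}

jax_image_ref maxtext              # nightly: ghcr.io/nvidia/jax:maxtext
jax_image_ref maxtext 2024-11-06   # pinned:  ghcr.io/nvidia/jax:maxtext-2024-11-06

# To fetch the pinned image (requires Docker and network access):
# docker pull "$(jax_image_ref maxtext 2024-11-06)"
```

Pinning a dated tag in job scripts keeps runs reproducible, while the undated tag moves with each nightly build.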

In addition to the public CI, we also run internal CI tests on H100 SXM 80GB and A100 SXM 80GB.

## Environment Variables

The [JAX image](https://github.com/NVIDIA/JAX-Toolbox/pkgs/container/jax) is embedded with the following flags and environment variables for performance tuning of XLA and NCCL:

| XLA Flags | Value | Explanation |
| --------- | ----- | ----------- |
| `--xla_gpu_enable_latency_hiding_scheduler` | `true` | allows XLA to move communication collectives to increase overlap with compute kernels |
| `--xla_gpu_enable_triton_gemm` | `false` | use cuBLAS instead of Triton GeMM kernels |

| Environment Variable | Value | Explanation |
| -------------------- | ----- | ----------- |
| `CUDA_DEVICE_MAX_CONNECTIONS` | `1` | use a single queue for GPU work to lower latency of stream operations; OK since XLA already orders launches |
| `NCCL_NVLS_ENABLE` | `0` | Disables NVLink SHARP ([1](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#nccl-nvls-enable)). Future releases will re-enable this feature. |

There are various other XLA flags users can set to improve performance. For a detailed explanation of these flags, please refer to the [GPU performance](./rosetta/docs/GPU_performance.md) doc. XLA flags can be tuned per workflow. For example, each script in [contrib/gpu/scripts_gpu](https://github.com/google/paxml/tree/main/paxml/contrib/gpu/scripts_gpu) sets its own [XLA flags](https://github.com/google/paxml/blob/93fbc8010dca95af59ab615c366d912136b7429c/paxml/contrib/gpu/scripts_gpu/benchmark_gpt_multinode.sh#L30-L33).
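
As a rough sketch, the defaults from the tables above could be reproduced in a shell outside the container as follows. It assumes the XLA flags are passed through the `XLA_FLAGS` environment variable, which is how JAX normally picks up XLA options; the values are copied from the tables and nothing else is added.

```bash
# XLA flags from the table above, passed as one space-separated string.
export XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=false"

# CUDA and NCCL settings from the table above.
export CUDA_DEVICE_MAX_CONNECTIONS=1
export NCCL_NVLS_ENABLE=0

# Print what a JAX process started from this shell would inherit.
printf 'XLA_FLAGS=%s\n' "${XLA_FLAGS}"
printf 'CUDA_DEVICE_MAX_CONNECTIONS=%s\n' "${CUDA_DEVICE_MAX_CONNECTIONS}"
printf 'NCCL_NVLS_ENABLE=%s\n' "${NCCL_NVLS_ENABLE}"
```

Inside a running container, `echo "$XLA_FLAGS"` shows the value the image actually sets, which is the safer starting point when tuning flags for a specific workflow.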

For a list of previously used XLA flags that are no longer needed, please also refer to the [GPU performance](./rosetta/docs/GPU_performance.md#previously-used-xla-flags) page.

## Versions

| First nightly with new base container | Base container |
| ------------------------------------- | -------------- |
| 2024-11-06 | nvidia/cuda:12.6.2-devel-ubuntu22.04 |
| 2024-09-25 | nvidia/cuda:12.6.1-devel-ubuntu22.04 |
| 2024-07-24 | nvidia/cuda:12.5.0-devel-ubuntu22.04 |

## Profiling
See [this page](./docs/profiling.md) for more information about how to profile JAX programs on GPU.

## Frequently asked questions (FAQ)

### `bus error` when running JAX in a Docker container

**Solution:**
```bash
docker run -it --shm-size=1g ...
```

**Explanation:**
The `bus error` might occur due to the size limitation of `/dev/shm`. You can address this by increasing the shared memory size using
the `--shm-size` option when launching your container.
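
To confirm that the option took effect, the size of `/dev/shm` can be checked from inside the container. The helper below is a minimal sketch; `fs_size_kb` is a hypothetical name, and it relies only on the POSIX output format of `df`.

```bash
# Print the total size, in KiB, of the filesystem that holds a path.
# Defaults to /dev/shm, the shared-memory mount discussed above.
fs_size_kb() {
  # In POSIX df output (-P), the second field of the data row is the
  # total size in 1024-byte blocks (-k).
  df -Pk "${1:-/dev/shm}" | awk 'NR==2 {print $2}'
}

# Only query /dev/shm when it exists, so the snippet is safe on any host.
if [ -d /dev/shm ]; then
  echo "/dev/shm size: $(fs_size_kb /dev/shm) KiB"
fi
```

A container started with `--shm-size=1g` should report roughly 1 GiB here; a much smaller value suggests the option was not passed to `docker run`.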

### enroot/pyxis reports error code 404 when importing multi-arch images

**Problem description:**
```
slurmstepd: error: pyxis: [INFO] Authentication succeeded
slurmstepd: error: pyxis: [INFO] Fetching image manifest list
slurmstepd: error: pyxis: [INFO] Fetching image manifest
slurmstepd: error: pyxis: [ERROR] URL https://ghcr.io/v2/nvidia/jax/manifests/ returned error code: 404 Not Found
```

**Solution:**
Upgrade [enroot](https://github.com/NVIDIA/enroot) or [apply a single-file patch](https://github.com/NVIDIA/enroot/releases/tag/v3.4.0) as described in the enroot v3.4.0 release notes.

**Explanation:**
Docker traditionally used Docker Schema V2.2 for multi-arch manifest lists, but has used the Open Container Initiative (OCI) format since version 20.10. Enroot added support for the OCI format in version 3.4.0.
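
A quick way to check whether an installed enroot is new enough is sketched below. It assumes that `enroot version` prints a bare version string such as `3.4.1` and that `sort -V` (GNU coreutils) is available; `version_at_least` is a hypothetical helper name.

```bash
# Succeed when version $1 is greater than or equal to version $2.
# Uses version-aware sorting: the smaller of the two versions sorts first.
version_at_least() {
  [ "$(printf '%s\n%s\n' "$2" "$1" | sort -V | head -n 1)" = "$2" ]
}

# Compare the installed enroot against the first release with OCI support.
if command -v enroot >/dev/null 2>&1; then
  installed="$(enroot version)"
  if version_at_least "${installed}" "3.4.0"; then
    echo "enroot ${installed} supports OCI manifests"
  else
    echo "enroot ${installed} predates 3.4.0; upgrade or apply the patch"
  fi
else
  echo "enroot not found on PATH"
fi
```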

## JAX on Public Clouds
* AWS
    * [Add EFA integration](https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-efa.html)
    * [SageMaker code sample](https://github.com/aws-samples/aws-samples-for-ray/tree/main/sagemaker/jax_alpa_language_model)
* GCP
    * [Getting started with JAX multi-node applications with NVIDIA GPUs on Google Kubernetes Engine](https://cloud.google.com/blog/products/containers-kubernetes/machine-learning-with-jax-on-kubernetes-with-nvidia-gpus)
* Azure
    * [Accelerating AI applications using the JAX framework on Azure’s NDm A100 v4 Virtual Machines](https://techcommunity.microsoft.com/t5/azure-high-performance-computing/accelerating-ai-applications-using-the-jax-framework-on-azure-s/ba-p/3735314)
* OCI
    * [Running a deep learning workload with JAX on multinode multi-GPU clusters on OCI](https://blogs.oracle.com/cloud-infrastructure/post/running-multinode-jax-clusters-on-oci-gpu-cloud)

## Resources
* [JAX | NVIDIA NGC Container](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/jax)
* [Slurm and OpenMPI zero config integration](https://jax.readthedocs.io/en/latest/_autosummary/jax.distributed.initialize.html)
* [Adding custom GPU ops](https://jax.readthedocs.io/en/latest/Custom_Operation_for_GPUs.html)
* [Triaging regressions](docs/triage-tool.md)

## Videos
* [Equinox for JAX: The Foundation of an Ecosystem for Science and Machine Learning](https://www.nvidia.com/en-us/on-demand/session/gtc24-s62668/)
* [Scaling Grok with JAX and H100](https://www.nvidia.com/en-us/on-demand/session/gtc24-s63257/)
* [JAX Supercharged on GPUs: High Performance LLMs with JAX and OpenXLA](https://www.nvidia.com/en-us/on-demand/session/gtc24-s62246/)
* [What's New in JAX | GTC Spring 2024](https://www.nvidia.com/en-us/on-demand/session/gtc24-s62659/)
* [What's New in JAX | GTC Spring 2023](https://www.nvidia.com/en-us/on-demand/session/gtcspring23-s51956/)