Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/mpi4jax/mpi4jax
Zero-copy MPI communication of JAX arrays, for turbo-charged HPC applications in Python :zap:
- Host: GitHub
- URL: https://github.com/mpi4jax/mpi4jax
- Owner: mpi4jax
- License: mit
- Created: 2020-07-21T10:57:03.000Z (over 4 years ago)
- Default Branch: master
- Last Pushed: 2024-06-11T14:29:46.000Z (8 months ago)
- Last Synced: 2024-06-11T16:39:25.479Z (8 months ago)
- Topics: gpu, high-performance-computing, jax, jit, mpi, parallel-computing, xla
- Language: Python
- Homepage: https://mpi4jax.readthedocs.io/
- Size: 4.94 MB
- Stars: 383
- Watchers: 10
- Forks: 27
- Open Issues: 18
Metadata Files:
- Readme: README.rst
- License: LICENSE.md
- Citation: CITATION.cff
Awesome Lists containing this project
- awesome-high-performance-computing - mpi4jax - Zero-copy mpi for jax arrays (Software / Trends)
- awesome-jax - mpi4jax - Zero-copy MPI communication of JAX arrays, for turbo-charged HPC applications in Python ⚡. (Libraries)
README
mpi4jax
=======

|JOSS paper| |PyPI Version| |Conda Version| |Tests| |codecov| |Documentation Status|

``mpi4jax`` enables zero-copy, multi-host communication of `JAX `_ arrays, even from jitted code and from GPU memory.

But why?
--------

The JAX framework `has great performance for scientific computing workloads `_, but its `multi-host capabilities `_ are still limited.

With ``mpi4jax``, you can scale your JAX-based simulations to *entire CPU and GPU clusters* (without ever leaving ``jax.jit``).
In the spirit of differentiable programming, ``mpi4jax`` also supports differentiating through some MPI operations.

Installation
------------

``mpi4jax`` is available through ``pip`` and ``conda``:

.. code:: bash

   $ pip install mpi4jax                     # Pip
   $ conda install -c conda-forge mpi4jax    # conda

Depending on the JAX backend you want to use, install ``mpi4jax`` in one of the following ways:

.. code:: bash

   # pip install 'jax[cpu]'
   $ pip install mpi4jax

   # pip install -U 'jax[cuda12]'
   $ pip install cython
   $ pip install mpi4jax --no-build-isolation

   # pip install -U 'jax[cuda12_local]'
   $ CUDA_ROOT=XXX pip install mpi4jax

(For more information on JAX GPU distributions, `see the JAX installation instructions `_.)

In case your MPI installation is not detected correctly, `it can help to install mpi4py separately `_. When using a pre-installed ``mpi4py``, you *must* use ``--no-build-isolation`` when installing ``mpi4jax``:

.. code:: bash

   # if mpi4py is already installed
   $ pip install cython
   $ pip install mpi4jax --no-build-isolation

`Our documentation includes some more advanced installation examples. `_
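
To decide which of the install variants above applies, it can help to check whether ``mpi4py`` is already importable in the current environment. The following is a minimal sketch using only the standard library. The helper names ``mpi4py_is_installed`` and ``suggested_pip_commands`` are hypothetical and not part of ``mpi4jax``, and the sketch covers only the pre-installed ``mpi4py`` case described above, not the choice of JAX backend.

.. code:: python

   import importlib.util


   def mpi4py_is_installed() -> bool:
       """Return True if mpi4py can be found in the current environment."""
       return importlib.util.find_spec("mpi4py") is not None


   def suggested_pip_commands(preinstalled: bool) -> list:
       """Map the two cases from the instructions above to pip commands."""
       if preinstalled:
           # A pre-installed mpi4py requires building without isolation.
           return [
               "pip install cython",
               "pip install mpi4jax --no-build-isolation",
           ]
       return ["pip install mpi4jax"]


   for command in suggested_pip_commands(mpi4py_is_installed()):
       print(command)

The script only prints the commands; it does not run them.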

Example usage
-------------

.. code:: python

   from mpi4py import MPI
   import jax
   import jax.numpy as jnp
   import mpi4jax

   comm = MPI.COMM_WORLD
   rank = comm.Get_rank()

   @jax.jit
   def foo(arr):
      # each process adds its own rank to the array
      arr = arr + rank
      # sum the arrays of all processes; the second return value is not needed here
      arr_sum, _ = mpi4jax.allreduce(arr, op=MPI.SUM, comm=comm)
      return arr_sum

   a = jnp.zeros((3, 3))
   result = foo(a)

   # print from a single process only
   if rank == 0:
      print(result)

Running this script on 4 processes gives:

.. code:: bash

   $ mpirun -n 4 python example.py
   [[6. 6. 6.]
    [6. 6. 6.]
    [6. 6. 6.]]

``allreduce`` is just one example of the MPI primitives you can use. `See all supported operations here `_.
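
To see where the value ``6`` in the output above comes from, the reduction can be reproduced without MPI. The following is a minimal sketch in plain Python, not ``mpi4jax`` code: the helper ``simulated_allreduce_sum`` is a hypothetical stand-in that only mimics what a sum reduction computes, with the four processes represented by a list.

.. code:: python

   from functools import reduce


   def simulated_allreduce_sum(per_rank_arrays):
       """Element-wise sum over all ranks (hypothetical stand-in, no MPI involved)."""
       def add(a, b):
           return [
               [x + y for x, y in zip(row_a, row_b)]
               for row_a, row_b in zip(a, b)
           ]
       return reduce(add, per_rank_arrays)


   n_ranks = 4
   # Each rank starts from a 3x3 array of zeros and adds its own rank index.
   contributions = [
       [[0.0 + rank] * 3 for _ in range(3)]
       for rank in range(n_ranks)
   ]
   result = simulated_allreduce_sum(contributions)
   print(result)

Every entry is 0 + 1 + 2 + 3 = 6, which matches the array printed by rank 0 in the real run.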

Community guidelines
--------------------

If you have a question or feature request, or want to report a bug, feel free to `open an issue `_.

We welcome contributions of any kind `through pull requests `_. For information on running our tests, debugging, and contribution guidelines please `refer to the corresponding documentation page `_.

How to cite
-----------

If you use ``mpi4jax`` in your work, please consider citing the following article:

::

   @article{mpi4jax,
     doi = {10.21105/joss.03419},
     url = {https://doi.org/10.21105/joss.03419},
     year = {2021},
     publisher = {The Open Journal},
     volume = {6},
     number = {65},
     pages = {3419},
     author = {Dion Häfner and Filippo Vicentini},
     title = {mpi4jax: Zero-copy MPI communication of JAX arrays},
     journal = {Journal of Open Source Software}
   }

.. |Tests| image:: https://github.com/mpi4jax/mpi4jax/workflows/Tests/badge.svg
   :target: https://github.com/mpi4jax/mpi4jax/actions?query=branch%3Amaster
.. |codecov| image:: https://codecov.io/gh/mpi4jax/mpi4jax/branch/master/graph/badge.svg
   :target: https://codecov.io/gh/mpi4jax/mpi4jax
.. |PyPI Version| image:: https://img.shields.io/pypi/v/mpi4jax
   :target: https://pypi.org/project/mpi4jax/
.. |Conda Version| image:: https://img.shields.io/conda/vn/conda-forge/mpi4jax.svg
   :target: https://anaconda.org/conda-forge/mpi4jax
.. |Documentation Status| image:: https://readthedocs.org/projects/mpi4jax/badge/?version=latest
   :target: https://mpi4jax.readthedocs.io/en/latest/?badge=latest
.. |JOSS paper| image:: https://joss.theoj.org/papers/10.21105/joss.03419/status.svg
   :target: https://doi.org/10.21105/joss.03419