https://github.com/gerlero/parajax
⚡ Automatic parallelization of calls to JAX-based functions
jax parallel-computing typed
- Host: GitHub
- URL: https://github.com/gerlero/parajax
- Owner: gerlero
- License: apache-2.0
- Created: 2025-09-27T15:56:15.000Z (4 months ago)
- Default Branch: main
- Last Pushed: 2025-10-03T19:23:00.000Z (3 months ago)
- Last Synced: 2025-10-06T23:04:03.216Z (3 months ago)
- Topics: jax, parallel-computing, typed
- Language: Python
- Homepage: https://parajax.readthedocs.io
- Size: 212 KB
- Stars: 0
- Watchers: 0
- Forks: 0
- Open Issues: 0
- Metadata Files:
  - Readme: README.md
  - License: LICENSE.txt
README
**Automagic parallelization of calls to [JAX](https://github.com/jax-ml/jax)-based functions**
[CI](https://github.com/gerlero/parajax/actions/workflows/ci.yml)
[Codecov](https://codecov.io/gh/gerlero/parajax)
[Ruff](https://github.com/astral-sh/ruff)
[ty](https://github.com/astral-sh/ty)
[uv](https://github.com/astral-sh/uv)
[Publish](https://github.com/gerlero/parajax/actions/workflows/pypi-publish.yml)
[PyPI](https://pypi.org/project/parajax/)
## Features
- 🚀 **Device-parallel execution**: run across multiple CPUs, GPUs or TPUs automatically
- ⚡ **Fully composable** with [`jax.jit`](https://docs.jax.dev/en/latest/_autosummary/jax.jit.html), [`jax.vmap`](https://docs.jax.dev/en/latest/_autosummary/jax.vmap.html), and other JAX transformations
- 🪄 **Automatic handling** of input shapes not divisible by the number of devices
- 🎯 **Simple interface**: just decorate your function with `autopmap`
## Installation
```bash
pip install parajax
```
## Example
```python
import multiprocessing

import jax
import jax.numpy as jnp

from parajax import autopmap

jax.config.update("jax_num_cpu_devices", multiprocessing.cpu_count())
# ^ Only needed on CPU: allow JAX to use all CPU cores


@autopmap
def square(x):
    return x**2


xs = jnp.arange(97)
ys = square(xs)
```
That's it! Invocations of `square` will now be automatically parallelized across all available devices.
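
Since the Features list states that `autopmap` is fully composable with `jax.jit` and other JAX transformations, the decorated function should also work as a `jit` target. The snippet below is a minimal sketch based on that claim rather than an excerpt from the parajax documentation (the `jitted_square` name is illustrative, not part of the parajax API):

```python
import multiprocessing

import jax
import jax.numpy as jnp

from parajax import autopmap

jax.config.update("jax_num_cpu_devices", multiprocessing.cpu_count())
# ^ Only needed on CPU: allow JAX to use all CPU cores


@autopmap
def square(x):
    return x**2


# Assumption based on the stated composability: an autopmap-decorated
# function can be wrapped in jax.jit like any other JAX-traceable callable.
jitted_square = jax.jit(square)

xs = jnp.arange(97)  # deliberately not a multiple of typical device counts
ys = jitted_square(xs)
```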
## Documentation
For more details, check out the [documentation](https://parajax.readthedocs.io/).