https://github.com/34j/array-api-jit
JIT decorator supporting multiple array API compatible libraries
- Host: GitHub
- URL: https://github.com/34j/array-api-jit
- Owner: 34j
- License: mit
- Created: 2025-07-11T12:36:58.000Z (7 months ago)
- Default Branch: main
- Last Pushed: 2025-09-30T21:30:17.000Z (4 months ago)
- Last Synced: 2025-09-30T23:25:51.630Z (4 months ago)
- Topics: array-api, cupy, jax, jit, numba, pytorch
- Language: Python
- Homepage:
- Size: 263 KB
- Stars: 0
- Watchers: 0
- Forks: 0
- Open Issues: 3
Metadata Files:
- Readme: README.md
- Changelog: CHANGELOG.md
- Contributing: CONTRIBUTING.md
- Funding: .github/FUNDING.yml
- License: LICENSE
- Code of conduct: .github/CODE_OF_CONDUCT.md
README
# array API JIT
---
**Documentation**: https://array-api-jit.readthedocs.io
**Source Code**: https://github.com/34j/array-api-jit
---
JIT decorator supporting multiple array API compatible libraries
## Installation
Install this via pip (or your favourite package manager):
```shell
pip install array-api-jit
```
## Usage
Simply decorate your function with `@jit()`:
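```python
from typing import Any

from array_api_compat import array_namespace
from array_api_jit import jit


@jit()
def my_function(x: Any) -> Any:
    # Resolve the array namespace (NumPy, CuPy, PyTorch, JAX, ...) from the input
    xp = array_namespace(x)
    return xp.sin(x) + xp.cos(x)
```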
## Advanced Usage
You can specify, for each library, which decorator to use and the arguments and keyword arguments passed to it.
```python
from typing import Any

import numba
from array_api_compat import array_namespace
from array_api_jit import jit


@jit(
    # numba.jit is not used for NumPy inputs by default because compilation may not succeed
    {"numpy": numba.jit()},
    decorator_kwargs={
        # jax.jit needs `n` to be static: range(n) requires a concrete Python int, not a traced value
        "jax": {"static_argnames": ["n"]},
    },
    # fail_on_error: bool = False,  # do not raise an error if the decorator fails (default)
    # rerun_on_error: bool = True,  # re-run the original function if the wrapped function fails (not the default)
)
def sin_n_times(x: Any, n: int) -> Any:
    xp = array_namespace(x)
    for _ in range(n):
        x = xp.sin(x)
    return x
```
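The two commented-out flags above control what happens when JIT compilation fails. A minimal sketch of the semantics their comments describe, assuming `fail_on_error` governs failures while applying the decorator and `rerun_on_error` governs failures when calling the wrapped function. `jit_with_fallback` is a hypothetical helper, not array-api-jit's API, and both defaults are set to `False` here as an assumption inferred from the comments.

```python
import functools
from typing import Any, Callable


def jit_with_fallback(
    func: Callable[..., Any],
    decorator: Callable[[Callable[..., Any]], Callable[..., Any]],
    fail_on_error: bool = False,
    rerun_on_error: bool = False,
) -> Callable[..., Any]:
    """Hypothetical illustration of fail_on_error / rerun_on_error."""
    try:
        wrapped = decorator(func)
    except Exception:
        if fail_on_error:
            raise
        # Decoration failed: silently fall back to the plain Python function
        return func

    @functools.wraps(func)
    def call(*args: Any, **kwargs: Any) -> Any:
        try:
            return wrapped(*args, **kwargs)
        except Exception:
            if not rerun_on_error:
                raise
            # The wrapped call failed (e.g. lazy compilation on first call):
            # re-run the original, uncompiled function instead
            return func(*args, **kwargs)

    return call
```

Handling call-time failures separately matters because decorators such as `numba.jit` compile lazily on the first call, so an unsupported function often fails only then rather than at decoration time. Re-running repeats the work, along with any side effects the failed call had already performed.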
## Contributors ✨
Thanks goes to these wonderful people ([emoji key](https://allcontributors.org/docs/en/emoji-key)):
This project follows the [all-contributors](https://github.com/all-contributors/all-contributors) specification. Contributions of any kind welcome!
## Credits
This package was created with
[Copier](https://copier.readthedocs.io/) and the
[browniebroke/pypackage-template](https://github.com/browniebroke/pypackage-template)
project template.