https://github.com/pynapple-org/pynajax
Jax backend for pynapple :fire:
- Host: GitHub
- URL: https://github.com/pynapple-org/pynajax
- Owner: pynapple-org
- License: MIT
- Created: 2024-02-23T10:52:09.000Z (about 2 years ago)
- Default Branch: main
- Last Pushed: 2025-04-04T21:11:15.000Z (11 months ago)
- Last Synced: 2025-08-26T18:56:56.831Z (7 months ago)
- Topics: jax, optimisation, pynapple
- Language: Python
- Homepage: https://pynapple-org.github.io/pynajax/
- Size: 2.12 MB
- Stars: 9
- Watchers: 2
- Forks: 0
- Open Issues: 3
Metadata Files:
- Readme: README.md
- License: LICENSE
# pynajax
[License](https://github.com/pynapple-org/pynajax/blob/main/LICENSE)

[Project Status: Active](https://www.repostatus.org/#active)
[CI](https://github.com/pynapple-org/pynajax/actions/workflows/ci.yml)
[Coverage](https://coveralls.io/github/pynapple-org/pynajax)

Welcome to `pynajax`, a GPU-accelerated backend for [pynapple](https://github.com/pynapple-org/pynapple) built on top of [jax](https://github.com/google/jax). It accelerates the core pynapple functions on the GPU.
> **Warning**
> ⚠️ This package is not meant to be used on its own. It should only be used through the pynapple API.
## Installation
Run the following `pip` command in your virtual environment.
**For macOS/Linux users:**
```bash
pip install pynajax
```
**For Windows users:**
```bash
python -m pip install pynajax
```
Alternatively, you can install pynapple and pynajax together.
```bash
pip install "pynapple[jax]"
```
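After installing, it can be useful to confirm which devices JAX can see, since computations fall back to the CPU when no GPU is available. Below is a minimal sketch; the `detect_backend` helper is hypothetical and not part of pynajax or pynapple.

```python
def detect_backend():
    """Return 'gpu' if JAX sees a GPU, 'cpu' otherwise, or None if jax is missing."""
    try:
        import jax
    except ImportError:
        return None
    # Each JAX device reports its platform ('cpu', 'gpu', or 'tpu').
    platforms = {d.platform for d in jax.devices()}
    return "gpu" if "gpu" in platforms else "cpu"

print(detect_backend())
```

If this prints `cpu`, pynajax will still work, but without GPU acceleration.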
## Basic usage
To use pynajax, change the pynapple backend with `nap.nap_config.set_backend`. See the example below:
```python
import pynapple as nap
import numpy as np
nap.nap_config.set_backend("jax")
tsd = nap.Tsd(t=np.arange(100), d=np.random.randn(100))
# This will run on GPU or CPU depending on the jax installation
tsd.convolve(np.ones(11))
```
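Conceptually, `convolve` slides a kernel along the time axis of the data. A plain-NumPy sketch of the equivalent operation (the "same"-length output mode is an assumption made here for illustration, not a statement of pynapple's exact semantics):

```python
import numpy as np

# Sketch of the operation tsd.convolve(np.ones(11)) performs on the
# underlying data array: a 1-D convolution along the time axis.
d = np.random.randn(100)
kernel = np.ones(11)
out = np.convolve(d, kernel, mode="same")
print(out.shape)  # one output sample per time point
```

pynajax replaces this kind of NumPy computation with a jitted JAX equivalent, which is what yields the GPU speedup.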
## Benchmark
This benchmark for the `convolve` function was run on a GPU.

See the documentation for others benchmarks.