# crn-jax

[![CI](https://github.com/robinhenry/crn-jax/actions/workflows/ci.yml/badge.svg)](https://github.com/robinhenry/crn-jax/actions/workflows/ci.yml)
[![PyPI](https://img.shields.io/pypi/v/crn-jax.svg)](https://pypi.org/project/crn-jax/)
[![Python](https://img.shields.io/pypi/pyversions/crn-jax.svg)](https://pypi.org/project/crn-jax/)
[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](LICENSE)
[![Ruff](https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/ruff/main/assets/badge/v2.json)](https://github.com/astral-sh/ruff)

Chemical reaction networks in JAX: a tiny, GPU-optimized Gillespie / Stochastic Simulation Algorithm (SSA) library.

Birth-death benchmark: crn-jax on GPU vs GillesPy2 (C++) on CPU.

Linear-cascade benchmark: crn-jax on GPU vs GillesPy2 (C++) on CPU.

Wall-clock time to simulate 1,000,000 independent stochastic trajectories, each a full Gillespie run of the reaction network from t=0 to t=20 sampled at 200 time points (CPU vs RTX 5090 GPU).

## Install

```bash
pip install crn-jax

# with NVIDIA GPU support:
pip install crn-jax "jax[cuda12]"

# with plotting helpers:
pip install "crn-jax[examples]"

# for local development (uses Poetry):
git clone https://github.com/robinhenry/crn-jax && cd crn-jax
poetry install # main deps + dev tools
poetry install --with gpu # add jax[cuda12] on an NVIDIA host
```

`crn-jax` depends on `jax` / `jaxlib` only.

## Key features

- 🎯 **Exact SSA**: pure-JAX implementation of the Gillespie algorithm for chemical reaction networks.
- ⚡ **JIT-compiled**: the entire loop compiles under `jax.jit`.
- 🚀 **GPU speedup**: 1M+ independent trajectories on a single GPU under `jax.vmap`, with no Python overhead.
- ⏱️ **Discretization-safe**: pending reaction times are preserved across simulation-interval boundaries, so trajectories are physically correct under discrete observations (or fixed-interval stepping).
- 🎛️ **Control-input aware**: propensities take an optional `input` argument that can vary per-interval and per-replicate, so each of N parallel trajectories can follow its own control schedule (useful for RL-style rollouts, closed-loop experiments with per-replicate inputs, …).
- 🧩 **Bring-your-own state**: the loop operates on any PyTree (NamedTuple, Flax struct dataclass, Equinox module, …).
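To make the exact-SSA and discretization-safe points concrete, here is a minimal pure-Python sketch of Gillespie's direct method for the same birth-death model used in the Quickstart, with the pending reaction time stored in the state and carried across interval boundaries. It is not crn-jax's JAX implementation: `State`, `step_to`, and `run` are hypothetical names, and the sketch assumes propensities that take no external input.

```python
import math
import random
from dataclasses import dataclass

BIRTH_RATE, DEATH_RATE = 3.0, 0.1  # lambda (birth) and mu (per-molecule death)


@dataclass
class State:
    time: float
    x: int
    next_reaction_time: float = math.inf  # pending event time, carried across intervals


def propensities(s):
    # a_0 = lambda (birth), a_1 = mu * x (death)
    return [BIRTH_RATE, DEATH_RATE * s.x]


def apply_reaction(s, j):
    s.x += 1 if j == 0 else -1


def sample_next_time(s, rng):
    # Waiting time ~ Exponential(a0); no further events if every propensity is 0.
    a0 = sum(propensities(s))
    return s.time + rng.expovariate(a0) if a0 > 0 else math.inf


def step_to(s, t_end, rng):
    """Hypothetical helper: advance s to t_end with the direct method (mutates s)."""
    if math.isinf(s.next_reaction_time):
        s.next_reaction_time = sample_next_time(s, rng)
    while s.next_reaction_time <= t_end:
        s.time = s.next_reaction_time
        a = propensities(s)
        # Pick reaction j with probability a[j] / sum(a).
        u = rng.random() * sum(a)
        j, acc = 0, a[0]
        while u >= acc and j < len(a) - 1:
            j += 1
            acc += a[j]
        apply_reaction(s, j)
        s.next_reaction_time = sample_next_time(s, rng)
    # The pending event lies beyond t_end: move the clock but keep the event time.
    s.time = t_end
    return s


def run(seed, n_intervals, dt):
    """Hypothetical driver: step one replicate through n_intervals of length dt."""
    rng = random.Random(seed)
    s = State(time=0.0, x=0)
    for k in range(1, n_intervals + 1):
        step_to(s, k * dt, rng)
    return s
```

With the same seed, `run(1, 100, 1.0)` and `run(1, 1, 100.0)` consume the same random draws in the same order and finish in the same state: carrying `next_reaction_time` makes the trajectory independent of the observation grid.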

## Quickstart

A 1-species birth-death process, `∅ → X` at rate λ and `X → ∅` at rate μ·x, simulated for 100 independent replicates:

```python
from typing import NamedTuple
import jax, jax.numpy as jnp
from crn_jax import simulate_trajectory, plot_trajectories

BIRTH_RATE, DEATH_RATE = 3.0, 0.1  # steady-state mean λ/μ = 30

# Define a state-holding object
class State(NamedTuple):
    time: jax.Array
    x: jax.Array
    next_reaction_time: jax.Array  # carried across intervals

# Return propensity equations as an array
# with an optional external input (unused here)
def propensities(s, _input):
    return jnp.array([BIRTH_RATE, DEATH_RATE * s.x])

# Describe how the state changes when reaction `j` fires
def apply_reaction(s, j):
    return s._replace(x=s.x + jnp.where(j == 0, 1.0, -1.0))

# Initial state: time = 0, x = 0, next_reaction_time = inf
state0 = State(jnp.array(0.0), jnp.array(0.0), jnp.array(jnp.inf))

@jax.jit
@jax.vmap
def run_one(key):
    return simulate_trajectory(
        key=key,
        initial_state=state0,
        timestep=1.0,
        n_steps=200,
        # Pass our 2 custom functions defined above
        compute_propensities_fn=propensities,
        apply_reaction_fn=apply_reaction,
    )

# Simulate 100 Gillespie trajectories (one PRNG key per replicate);
# each field of `states` is stacked to shape (100, 200): replicates x steps
states = run_one(jax.random.split(jax.random.PRNGKey(0), 100))
times = jnp.arange(1, 201) * 1.0  # end time of each of the 200 intervals
```
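The `λ/μ = 30` comment in the code follows from the mean equation of the birth-death process. Both propensities are linear in x, so the equation for the ensemble mean closes exactly (birth adds +1 at rate λ, death adds −1 at rate μx):

```latex
\frac{d\langle x\rangle}{dt} = \lambda - \mu\,\langle x\rangle,
\qquad
\langle x\rangle(t) = \frac{\lambda}{\mu}\left(1 - e^{-\mu t}\right)\ \text{for } x(0) = 0,
\qquad
\langle x\rangle_{\text{ss}} = \frac{\lambda}{\mu} = \frac{3.0}{0.1} = 30.
```

With μ = 0.1 the relaxation time is 1/μ = 10 time units, so by the end of the 200 unit-length intervals the ensemble mean should sit close to 30; at stationarity the copy number of this process is Poisson-distributed with mean λ/μ.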

See the [examples](examples/) folder for more detailed examples.

## API

```python
# Main entry point: scan n_steps fixed-length intervals, stack the per-step states.
from crn_jax import simulate_trajectory

# Finer control: one interval at a time (RL-style), or until an absolute time.
from crn_jax.gillespie import simulate_interval, simulate_until

# Plotting helper: step-plots a single trajectory or an (N, T) ensemble.
from crn_jax import plot_trajectories

# Optional kinetic-law helpers.
from crn_jax.kinetics import hill_function, sample_lognormal
```

| function | when to reach for it |
| --------------------- | ---------------------------------------------------------------------------- |
| `simulate_trajectory` | You want a full trajectory on a fixed sampling grid. Start here. |
| `simulate_interval` | You're driving the system yourself, one step at a time (e.g. an RL rollout). |
| `simulate_until` | You need a custom state shape or a non-uniform time grid. Fully generic. |
| `plot_trajectories` | Quick look at the output. |
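This README does not document the signatures of `hill_function` or `sample_lognormal`, so the plain-Python sketch below uses hypothetical stand-ins (`hill_activation`, `hill_repression`, `lognormal_rate`, `autorepressed_propensities`) to show the standard formulas such kinetic-law helpers typically cover: Hill-type activation and repression for regulated propensities, and positive rate constants drawn from a lognormal distribution (for example, to give each replicate its own parameters). Check the crn-jax source for the real functions before relying on any of this.

```python
import math
import random


def hill_activation(x, k, n):
    """x^n / (k^n + x^n): 0 at x = 0, exactly 1/2 at x = k, approaches 1 for x >> k."""
    return x**n / (k**n + x**n)


def hill_repression(x, k, n):
    """k^n / (k^n + x^n), i.e. 1 - hill_activation(x, k, n)."""
    return k**n / (k**n + x**n)


def lognormal_rate(median, sigma, rng):
    """Positive rate whose natural log is Normal(log(median), sigma)."""
    return rng.lognormvariate(math.log(median), sigma)


def autorepressed_propensities(x, birth_rate, death_rate=0.1, k=20.0, n=2):
    # Production is repressed by its own product; degradation stays linear in x.
    return [birth_rate * hill_repression(x, k, n), death_rate * x]


# One production rate per replicate, scattered around a median of 3.0.
rng = random.Random(0)
birth_rates = [lognormal_rate(3.0, 0.3, rng) for _ in range(2001)]
```

Parameterizing the lognormal by its median keeps the "typical" rate fixed while `sigma` controls the replicate-to-replicate spread in log space.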

## See Also

* [GillesPy2](https://github.com/StochSS/GillesPy2): C++-optimized Gillespie simulations on CPU.
* [myriad-jax](https://github.com/robinhenry/myriad-jax): RL-style decision making fully in JAX, powered by `crn-jax` at its core.