https://github.com/avitai/datarax
A Differentiable Data Pipeline Framework for JAX
autograd data data-analysis data-science differentiable flax-nnx jax jit machine-learning xla
- Host: GitHub
- URL: https://github.com/avitai/datarax
- Owner: avitai
- License: mit
- Created: 2026-01-12T23:28:16.000Z (5 months ago)
- Default Branch: main
- Last Pushed: 2026-04-24T23:56:30.000Z (about 2 months ago)
- Last Synced: 2026-04-25T01:34:54.205Z (about 2 months ago)
- Topics: autograd, data, data-analysis, data-science, differentiable, flax-nnx, jax, jit, machine-learning, xla
- Language: Python
- Homepage: https://datarax.readthedocs.io/en/latest/
- Size: 8.62 MB
- Stars: 1
- Watchers: 0
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
- Contributing: docs/contributing/contributing_guide.md
- License: LICENSE
# Datarax: A Data Pipeline Framework for JAX
---
> **Early Development - API Unstable**
>
> Datarax is in early development and undergoing rapid iteration.
> Breaking changes are expected. Pin to specific commits if stability is required.
> We recommend waiting for a stable release (v1.0) before using Datarax in production.
---
**Datarax** (*Data + Array/JAX*) is an extensible data pipeline framework built for JAX-based machine learning workflows. It leverages JAX's JIT compilation, automatic differentiation, and hardware acceleration to build data loading, preprocessing, and augmentation pipelines that run on CPUs, GPUs, and TPUs.
## Key Features
- **JAX-Native Design:** All core components built on JAX's functional paradigm with Flax NNX module system for state management
- **High Performance:** JIT-compiled pipelines via XLA, with built-in profiling and roofline analysis
- **DAG Execution Engine:** Graph-based pipeline construction with branching, parallel execution, caching, and rebatching nodes
- **Scalability:** Multi-device and multi-host data distribution with device mesh sharding
- **Determinism:** Reproducible pipelines by default using Grain's Feistel cipher shuffling (O(1) memory)
- **Extensibility:** Custom data sources, operators, and augmentation strategies via composable NNX modules
- **Benchmarking Suite:** Comparative benchmarks against 12+ frameworks with Calibrax-powered analysis and regression checks
- **Ecosystem Integration:** Works with Flax, Optax, Orbax, HuggingFace Datasets, and TensorFlow Datasets
## Why Datarax?
JAX has mature libraries for models (Flax), optimizers (Optax), and checkpointing (Orbax), but lacks a dedicated data pipeline framework that operates at the same level of abstraction. Existing options are either framework-agnostic loaders that return NumPy arrays (losing JIT/autodiff benefits) or wrappers around tf.data/PyTorch that introduce cross-framework overhead. Datarax aims to fill this gap. The framework is under active development with ongoing performance optimization: the architecture is functional, but throughput and the API surface are still being refined.
### JAX-Native from the Ground Up
Every component — sources, operators, batchers, samplers, sharders — is a Flax NNX module. Pipeline state is managed through NNX's variable system, which means operators can hold learnable parameters, be serialized with Orbax, and participate in JAX transformations (`jit`, `vmap`, `grad`) without special handling.
### Differentiable Data Pipelines
Because operators are NNX modules, gradients flow through the entire pipeline. This enables approaches that are not possible with standard data loaders:
- [Gradient-based augmentation search](examples/advanced/differentiable/01_dada_learned_augmentation_guide.py) — replacing RL-based methods like AutoAugment with direct optimization
- [Task-optimized preprocessing](examples/advanced/differentiable/02_learned_isp_guide.py) — backpropagating task loss through every processing stage
- [Differentiable audio synthesis](examples/advanced/differentiable/03_ddsp_audio_synthesis_guide.py) — extending the same pattern to non-vision domains
See the [differentiable pipeline examples](docs/examples/advanced/differentiable/) for details.
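To make the idea concrete, here is a minimal sketch in plain JAX (it does not use Datarax's operator API). The `augment` function, its `contrast` and `brightness` parameters, and the squared-error loss are hypothetical stand-ins chosen for illustration. The point is only that `jax.grad` can differentiate a loss with respect to augmentation parameters when the augmentation itself is written in JAX.

```python
import jax
import jax.numpy as jnp


def augment(params, image):
    # Hypothetical differentiable augmentation: affine contrast/brightness change.
    return params["contrast"] * image + params["brightness"]


def loss_fn(params, image, target):
    # Stand-in task loss: mean squared error between augmented image and target.
    return jnp.mean((augment(params, image) - target) ** 2)


params = {"contrast": jnp.array(1.0), "brightness": jnp.array(0.0)}
image = jnp.full((4, 4), 0.5)
target = jnp.full((4, 4), 0.75)

# Gradients of the loss with respect to the augmentation parameters.
grads = jax.grad(loss_fn)(params, image, target)

# One plain gradient-descent step on the augmentation parameters.
learning_rate = 0.1
new_params = jax.tree_util.tree_map(
    lambda p, g: p - learning_rate * g, params, grads
)
```

In a real pipeline the loss would come from the downstream model, so the same gradient call would update model weights and augmentation parameters together.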
### DAG Execution Model
Pipelines are directed acyclic graphs, not linear chains. The `>>` operator composes sequential steps, `|` creates parallel branches, and control-flow nodes (`Branch`, `Merge`, `SplitField`) handle conditional and multi-path logic. The DAG executor manages scheduling, caching, and rebatching across the graph.
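As a rough illustration of how `>>` and `|` composition can work in Python, the toy sketch below overloads `__rshift__` and `__or__` on a small node class. All class names here (`Node`, `Fn`, `Sequential`, `Parallel`, `MergeMean`) are hypothetical and are not Datarax's actual classes. The sketch also omits scheduling, caching, and rebatching.

```python
from typing import Any, Callable


class Node:
    """Hypothetical base node; composition operators build larger nodes."""

    def __call__(self, x: Any) -> Any:
        raise NotImplementedError

    def __rshift__(self, other: "Node") -> "Node":
        # a >> b: run a, then feed its output to b.
        return Sequential(self, other)

    def __or__(self, other: "Node") -> "Node":
        # a | b: run both on the same input and collect the outputs.
        return Parallel(self, other)


class Fn(Node):
    def __init__(self, fn: Callable[[Any], Any]) -> None:
        self.fn = fn

    def __call__(self, x: Any) -> Any:
        return self.fn(x)


class Sequential(Node):
    def __init__(self, first: Node, second: Node) -> None:
        self.first, self.second = first, second

    def __call__(self, x: Any) -> Any:
        return self.second(self.first(x))


class Parallel(Node):
    def __init__(self, *branches: Node) -> None:
        self.branches = branches

    def __call__(self, x: Any) -> Any:
        return [branch(x) for branch in self.branches]


class MergeMean(Node):
    def __call__(self, xs: Any) -> Any:
        # Average the outputs of the parallel branches.
        return sum(xs) / len(xs)


pipeline = (
    Fn(lambda x: x / 4)
    >> (Fn(lambda x: x * 2) | Fn(lambda x: x + 1))
    >> MergeMean()
)
result = pipeline(8.0)  # 8.0 -> 2.0 -> [4.0, 3.0] -> 3.5
```

Python gives `>>` higher precedence than `|`, which is why the parallel group needs its own parentheses, as in the Advanced example further down.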
### Deterministic Reproducibility
Shuffling uses Grain's Feistel cipher permutation, which generates a full-epoch permutation in O(1) memory without materializing the index array. Combined with explicit RNG key threading through every stochastic operator, pipelines produce identical output given the same seed — across restarts, devices, and host counts.
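The sketch below shows the general technique of a Feistel-network index permutation using only the standard library. It is not Grain's implementation: the round function (BLAKE2b over a formatted string), the round count, and the name `feistel_permute` are assumptions made for illustration. Each index is mapped on demand, so no array of `n` indices is ever stored.

```python
import hashlib


def _round_fn(value: int, round_idx: int, seed: int, mask: int) -> int:
    # Keyed pseudo-random function for one Feistel round (illustrative choice).
    data = f"{seed}:{round_idx}:{value}".encode()
    digest = hashlib.blake2b(data, digest_size=8).digest()
    return int.from_bytes(digest, "big") & mask


def feistel_permute(index: int, n: int, seed: int, rounds: int = 4) -> int:
    """Map index in [0, n) to a unique shuffled position in [0, n)."""
    if not 0 <= index < n:
        raise IndexError(f"index {index} out of range for n={n}")
    # Split the smallest even-width bit domain covering n into two halves.
    half_bits = max(1, ((n - 1).bit_length() + 1) // 2)
    mask = (1 << half_bits) - 1
    x = index
    while True:
        left, right = x >> half_bits, x & mask
        for r in range(rounds):
            left, right = right, left ^ _round_fn(right, r, seed, mask)
        x = (left << half_bits) | right
        # Cycle-walk: re-encrypt until the value falls back inside [0, n).
        if x < n:
            return x


# Visit a 10-element dataset in shuffled order without building a permutation array.
order = [feistel_permute(i, 10, seed=0) for i in range(10)]
```

Because the mapping depends only on `(index, n, seed)`, any worker can recompute its share of the epoch order after a restart, which is what makes this style of shuffling convenient for reproducibility.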
### Built-in Competitive Benchmarking
The benchmarking suite profiles Datarax against 12+ frameworks (Grain, tf.data, PyTorch DataLoader, DALI, Ray Data, and others) across standardized scenarios. Results are converted to CalibraX runs for direction-aware metrics, regression gating, and W&B export. This benchmark-driven loop is how Datarax tracks progress toward competitive throughput; current results and optimization status are tracked in the [benchmarking documentation](docs/benchmarks/index.md).
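"Direction-aware" regression gating can be read as follows: for throughput, a drop is a regression, while for latency, a rise is. The sketch below illustrates that idea only. The `Metric` class, the `is_regression` function, and the 5% default tolerance are hypothetical and do not reflect CalibraX's actual API.

```python
from dataclasses import dataclass
from typing import Literal


@dataclass(frozen=True)
class Metric:
    name: str
    direction: Literal["higher", "lower"]  # which direction counts as better


def is_regression(
    metric: Metric, baseline: float, current: float, tolerance: float = 0.05
) -> bool:
    """Return True if `current` is worse than `baseline` by more than `tolerance`."""
    if baseline <= 0:
        raise ValueError("baseline must be positive")
    change = (current - baseline) / baseline
    if metric.direction == "higher":
        return change < -tolerance  # e.g. throughput dropped
    return change > tolerance  # e.g. latency increased


throughput = Metric("samples_per_second", "higher")
latency = Metric("batch_latency_ms", "lower")

throughput_regressed = is_regression(throughput, baseline=1000.0, current=900.0)
latency_regressed = is_regression(latency, baseline=10.0, current=10.3)
```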
## Installation
```bash
# Basic installation
uv pip install datarax
# With data loading support (HuggingFace, TFDS, audio/image libs)
uv pip install "datarax[data]"
# With GPU support (CUDA 12)
uv pip install "datarax[gpu]"
# Full development installation
uv pip install "datarax[all]"
```
### macOS / Apple Silicon
```bash
# macOS CPU mode (recommended)
uv pip install "datarax[all-cpu]"
JAX_PLATFORMS=cpu python your_script.py
# Metal GPU acceleration (experimental, M1/M2/M3+)
uv pip install jax-metal
JAX_PLATFORMS=metal python your_script.py
```
> **Note:** Metal GPU acceleration is community-tested. CI runs on macOS with CPU only.
## Quick Start
```python
import jax
import jax.numpy as jnp
import numpy as np
from flax import nnx
from datarax import build_source_pipeline
from datarax.dag.nodes import OperatorNode
from datarax.operators import ElementOperator, ElementOperatorConfig
from datarax.sources import MemorySource, MemorySourceConfig
from datarax.typing import Element


def normalize(element: Element, key: jax.Array | None = None) -> Element:
    return element.update_data({"image": element.data["image"] / 255.0})


def augment(element: Element, key: jax.Array) -> Element:
    key1, _ = jax.random.split(key)
    flip = jax.random.bernoulli(key1, 0.5)
    new_image = jax.lax.cond(
        flip, lambda img: jnp.flip(img, axis=1), lambda img: img,
        element.data["image"],
    )
    return element.update_data({"image": new_image})


# Create in-memory data source
data = {
    "image": np.random.randint(0, 255, (1000, 28, 28, 1)).astype(np.float32),
    "label": np.random.randint(0, 10, (1000,)).astype(np.int32),
}
source = MemorySource(MemorySourceConfig(), data=data, rngs=nnx.Rngs(0))

# Build pipeline with DAG-based API
normalizer = ElementOperator(
    ElementOperatorConfig(stochastic=False), fn=normalize, rngs=nnx.Rngs(0),
)
augmenter = ElementOperator(
    ElementOperatorConfig(stochastic=True, stream_name="augmentations"),
    fn=augment, rngs=nnx.Rngs(42),
)
pipeline = (
    build_source_pipeline(source, batch_size=32)
    >> OperatorNode(normalizer)
    >> OperatorNode(augmenter)
)

# Process batches
for i, batch in enumerate(pipeline):
    if i >= 3:
        break
    print(f"Batch {i}: images {batch['image'].shape}, labels {batch['label'].shape}")
```
### Advanced: Branching and Parallel DAGs
```python
from datarax.dag.nodes import OperatorNode, Merge, Branch
# Define additional operators
def invert(element: Element, key=None) -> Element:
    return element.update_data({"image": 1.0 - element.data["image"]})


inverter = ElementOperator(
    ElementOperatorConfig(stochastic=False), fn=invert, rngs=nnx.Rngs(0),
)


def is_high_contrast(element):
    return jnp.var(element.data["image"]) > 0.1


# Build a complex DAG:
# 1. Source -> Batching
# 2. Parallel: normalizer AND inverter (| creates a Parallel node)
# 3. Merge: average the two branches
# 4. Branch: conditional path based on image variance
complex_pipeline = (
    build_source_pipeline(source, batch_size=32)
    >> (OperatorNode(normalizer) | OperatorNode(inverter))
    >> Merge("mean")
    >> Branch(
        condition=is_high_contrast,
        true_path=OperatorNode(augmenter),
        false_path=OperatorNode(normalizer),
    )
)
```
## Architecture
```text
src/datarax/
  core/         # Base modules: DataSourceModule, OperatorModule, Element, Batcher, Sampler, Sharder
  dag/          # DAG executor and node system (source, operator, batch, cache, control flow)
  sources/      # MemorySource, TFDS (eager/streaming), HuggingFace (eager/streaming), ArrayRecord, MixedSource
  operators/    # ElementOperator, MapOperator, CompositeOperator, modality-specific (image, text)
  strategies/   # Sequential, Parallel, Branching, Ensemble, Merging execution strategies
  samplers/     # Sequential, Shuffle (Feistel cipher), Range, EpochAware samplers
  sharding/     # ArraySharder, JaxProcessSharder for multi-device distribution
  distributed/  # DeviceMesh, DataParallel for multi-host training
  batching/     # DefaultBatcher with buffer state management
  checkpoint/   # NNXCheckpointHandler with Orbax integration
  monitoring/   # Pipeline monitor, DAG monitor, reporters
  performance/  # Roofline analysis, XLA optimization utilities
  control/      # Prefetcher for asynchronous data loading
  memory/       # Shared memory manager for multi-process data sharing
  config/       # TOML-based configuration system with schema validation
  cli/          # datarax CLI entry point
  utils/        # PyTree utilities, external integration helpers
```
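The `control/` entry above mentions a prefetcher for asynchronous data loading. The general pattern can be sketched with a background thread and a bounded queue. This is a standard-library illustration only: the `prefetch` function and its `buffer_size` parameter are hypothetical, and Datarax's own Prefetcher may work differently.

```python
import queue
import threading
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")
_DONE = object()  # sentinel marking the end of the stream


def prefetch(iterable: Iterable[T], buffer_size: int = 2) -> Iterator[T]:
    """Yield items from `iterable` while a background thread loads ahead."""
    buf: queue.Queue = queue.Queue(maxsize=buffer_size)

    def producer() -> None:
        try:
            for item in iterable:
                buf.put(item)  # blocks while the buffer is full
        finally:
            # Always signal completion so the consumer does not hang.
            # Note: an exception raised by `iterable` is not re-raised here.
            buf.put(_DONE)

    threading.Thread(target=producer, daemon=True).start()
    while True:
        item = buf.get()
        if item is _DONE:
            return
        yield item


batches = list(prefetch(range(5), buffer_size=2))
```

The bounded queue caps memory use: the producer stays at most `buffer_size` items ahead of the consumer, so loading overlaps with compute without unbounded buffering.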
## Benchmarking
Datarax includes a benchmarking suite for comparison against 12+ data loading frameworks across a range of workload scenarios (vision, NLP, tabular, multimodal, distributed).
```bash
# Install benchmark dependencies (adds PyTorch, DALI, Ray, etc.)
uv sync --extra benchmark
# Optional: install CalibraX with W&B support explicitly
uv pip install "calibrax[wandb] @ git+https://github.com/avitai/calibrax.git"
# Run benchmarks locally
uv run python -m benchmarks.runners.full_runner --platform cpu --repetitions 5
# Run on cloud (SkyPilot)
sky launch benchmarks/sky/gpu-benchmark.yaml --env WANDB_API_KEY=$WANDB_API_KEY
```
Benchmark results are exported to W&B with charts, gap analysis, stability reports, and raw result artifacts. See [Benchmarking Guide](docs/benchmarks/index.md) for methodology and cloud deployment.
## Development Setup
Datarax uses `uv` as its package manager:
```bash
# Clone and setup
git clone https://github.com/avitai/datarax.git
cd datarax
# Automatic setup
./setup.sh && source activate.sh
# Or manual install
uv sync --extra dev
```
### Running Tests
```bash
# CPU-only (most stable)
JAX_PLATFORMS=cpu uv run pytest
# Include benchmark test suite in the same run
JAX_PLATFORMS=cpu uv run pytest --all-suites
# Specific module
JAX_PLATFORMS=cpu uv run pytest tests/sources/test_memory_source_module.py
```
### Docker
```bash
# Build and run
docker build -t datarax:latest .
docker run --rm --gpus all datarax:latest python -c "import datarax, jax; print(jax.devices())"
# Benchmark images
docker build -f benchmarks/docker/Dockerfile.gpu -t datarax-bench:gpu .
```
See [Docker Guide](docs/contributing/docker.md) for full details.
## Documentation
- [Installation Guide](docs/getting_started/installation.md)
- [Quick Start](docs/getting_started/quick_start.md)
- [Core Concepts](docs/getting_started/core_concepts.md)
- [User Guide](docs/user_guide/)
- [API Reference](docs/api_reference/index.md)
- [Examples](docs/examples/overview.md)
- [Benchmarking](docs/benchmarks/index.md)
- [Contributing](docs/contributing/contributing_guide.md)
- [Docker](docs/contributing/docker.md)
## License
Datarax is licensed under the [MIT License](LICENSE).