https://github.com/patrick-kidger/jaxtyping

Type annotations and runtime checking for shape and dtype of JAX/NumPy/PyTorch/etc. arrays. https://docs.kidger.site/jaxtyping/

# jaxtyping

A library providing type annotations **and runtime type-checking** for the shape and dtype of JAX/PyTorch/NumPy/MLX/TensorFlow arrays and tensors.

_The name 'jax'typing is now historical: we support all of the above and have no JAX dependency!_

```python
from jaxtyping import Float
from torch import Tensor

# Accepts floating-point 2D arrays with matching axes
def matrix_multiply(x: Float[Tensor, "dim1 dim2"],
                    y: Float[Tensor, "dim2 dim3"]
                  ) -> Float[Tensor, "dim1 dim3"]:
    ...
```
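Conceptually, each name in a shape string is bound to a size the first time it is seen, and every later use of that name must agree; that is what "matching axes" means for `matrix_multiply`. The stdlib-only sketch below mimics that rule for the same signature. `bind_dims` and `shapes_match` are hypothetical helpers written for illustration only: they are not part of jaxtyping, they ignore dtypes, and they support none of the library's richer shape syntax.

```python
# Hypothetical toy checker: illustrates the named-dimension rule only.
# This is NOT how jaxtyping is implemented.


def bind_dims(spec: str, shape: tuple[int, ...], memo: dict[str, int]) -> None:
    """Check `shape` against a space-separated spec such as "dim1 dim2".

    Each name is bound to a size the first time it is seen; later
    occurrences (in this or another argument) must agree with that size.
    """
    names = spec.split()
    if len(names) != len(shape):
        raise TypeError(f"expected {len(names)} axes, got shape {shape}")
    for name, size in zip(names, shape):
        bound = memo.setdefault(name, size)  # first use binds the name
        if bound != size:
            raise TypeError(f"axis {name!r} is {size}, but was bound to {bound}")


def shapes_match(
    x_shape: tuple[int, ...],
    y_shape: tuple[int, ...],
    out_shape: tuple[int, ...],
) -> bool:
    """Apply the matrix_multiply annotations to three concrete shapes."""
    memo: dict[str, int] = {}  # one memo per call, shared by all arguments
    try:
        bind_dims("dim1 dim2", x_shape, memo)
        bind_dims("dim2 dim3", y_shape, memo)
        bind_dims("dim1 dim3", out_shape, memo)
    except TypeError:
        return False
    return True


print(shapes_match((2, 3), (3, 4), (2, 4)))  # True
print(shapes_match((2, 3), (5, 4), (2, 4)))  # False: dim2 is 3, then 5
print(shapes_match((2, 3), (3, 4), (4, 2)))  # False: dim1 is 2, then 4
```

Sharing a single `memo` across all three checks is the design point: it is what lets the check on `y` notice a conflict with a size that was bound while checking `x`.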

## Installation

```bash
pip install jaxtyping
```

Requires Python 3.10+.

The annotations provided by jaxtyping are compatible with runtime type-checking packages, so it is common to also install one of these. The two most popular are [typeguard](https://github.com/agronholm/typeguard) (which exhaustively checks every argument) and [beartype](https://github.com/beartype/beartype) (which checks random pieces of arguments).

## Documentation

Available at [https://docs.kidger.site/jaxtyping](https://docs.kidger.site/jaxtyping).

## See also: other libraries in the JAX ecosystem

**Always useful**

- [Equinox](https://github.com/patrick-kidger/equinox): neural networks and everything not already in core JAX!

**Deep learning**

- [Optax](https://github.com/deepmind/optax): first-order gradient (SGD, Adam, ...) optimisers.
- [Orbax](https://github.com/google/orbax): checkpointing (async/multi-host/multi-device).
- [Levanter](https://github.com/stanford-crfm/levanter): scalable+reliable training of foundation models (e.g. LLMs).
- [paramax](https://github.com/danielward27/paramax): parameterizations and constraints for PyTrees.

**Scientific computing**

- [Diffrax](https://github.com/patrick-kidger/diffrax): numerical differential equation solvers.
- [Optimistix](https://github.com/patrick-kidger/optimistix): root finding, minimisation, fixed points, and least squares.
- [Lineax](https://github.com/patrick-kidger/lineax): linear solvers.
- [BlackJAX](https://github.com/blackjax-devs/blackjax): probabilistic+Bayesian sampling.
- [sympy2jax](https://github.com/patrick-kidger/sympy2jax): SymPy<->JAX conversion; train symbolic expressions via gradient descent.
- [PySR](https://github.com/milesCranmer/PySR): symbolic regression. (Non-JAX honourable mention!)

**Awesome JAX**

- [Awesome JAX](https://github.com/n2cholas/awesome-jax): a longer list of other JAX projects.