https://github.com/galacticdynamics/unxt
Unitful Quantities in JAX
- Host: GitHub
- URL: https://github.com/galacticdynamics/unxt
- Owner: GalacticDynamics
- License: bsd-3-clause
- Created: 2023-12-22T22:13:31.000Z (about 2 years ago)
- Default Branch: main
- Last Pushed: 2024-10-29T12:33:40.000Z (about 1 year ago)
- Last Synced: 2024-10-29T14:59:11.831Z (about 1 year ago)
- Topics: array, auto-differentiation, jax, quantities, units
- Language: Python
- Homepage: https://unxt.readthedocs.io/
- Size: 1.1 MB
- Stars: 13
- Watchers: 1
- Forks: 3
- Open Issues: 16
Metadata Files:
- Readme: README.md
- Contributing: .github/CONTRIBUTING.md
- License: LICENSE
README
unxt
Unitful Quantities in JAX
---
Unxt provides unitful quantities and calculations in [JAX][jax], built on
[Equinox][equinox] and [Quax][quax].
Unxt supports JAX's compelling features:
- JIT compilation (`jit`)
- vectorization (`vmap`, etc.)
- auto-differentiation (`grad`, `jacobian`, `hessian`)
- GPU/TPU/multi-host acceleration
And best of all, `unxt` doesn't force you to use special unit-compatible
re-exports of JAX libraries. You can use `unxt` with existing JAX code, and with
[quax][quax]'s simple decorator, JAX will work with `unxt.Quantity`.
## Installation
[![PyPI version][pypi-version]][pypi-link]
[![PyPI platforms][pypi-platforms]][pypi-link]
```bash
pip install unxt
```
using uv
```bash
uv add unxt
```
from source, using pip
```bash
pip install git+https://github.com/GalacticDynamics/unxt.git
```
building from source
```bash
cd /path/to/parent
git clone https://github.com/GalacticDynamics/unxt.git
cd unxt
pip install -e . # editable mode
```
## [Documentation][rtd-link]
### Quick example
```python
import unxt as u
import jax.numpy as jnp
x = u.Quantity(jnp.arange(1, 5, dtype=float), "km")
print(x)
# Quantity['length']([1., 2., 3., 4.], unit='km')
```
The constituent value and unit are accessible as attributes:
```python
repr(x.value)
# Array([1., 2., 3., 4.], dtype=float64)
repr(x.unit)
# Unit("km")
```
`Quantity` objects obey the rules of unitful arithmetic.
```python
# Addition / Subtraction
print(x + x)
# Quantity['length']([2., 4., 6., 8.], unit='km')
# Multiplication / Division
print(2 * x)
# Quantity['length']([2., 4., 6., 8.], unit='km')
y = u.Quantity(jnp.arange(4, 8, dtype=float), "yr")
print(x / y)
# Quantity['speed']([0.25, 0.4 , 0.5 , 0.57142857], unit='km / yr')
# Exponentiation
print(x**2)
# Quantity['area']([ 1., 4., 9., 16.], unit='km2')
# Unit checking on operations
try:
    x + y
except Exception as e:
    print(e)
# 'yr' (time) and 'km' (length) are not convertible
```
Quantities can be converted to different units:
```python
print(u.uconvert("m", x)) # via function
# Quantity['length']([1000., 2000., 3000., 4000.], unit='m')
print(x.uconvert("m")) # via method
# Quantity['length']([1000., 2000., 3000., 4000.], unit='m')
```
Since `Quantity` is parametric, it can do runtime dimension checking!
```python
LengthQuantity = u.Quantity["length"]
print(LengthQuantity(2, "km"))
# Quantity['length'](2, unit='km')
try:
    LengthQuantity(2, "s")
except ValueError as e:
    print(e)
# Physical type mismatch.
```
For static configuration values (e.g., JAX static arguments), use
`StaticQuantity`, which stores NumPy values and rejects JAX arrays:
```python
import numpy as np
from functools import partial
import jax
import jax.numpy as jnp
import unxt as u
cfg = u.StaticQuantity(np.array([1.0, 2.0]), "m")
@partial(jax.jit, static_argnames=("q",))
def add(x, q):
    return x + jnp.asarray(q.value)
print(add(1.0, cfg))
```
If you want a `Quantity` that keeps a static value but still participates in
regular arithmetic, wrap the value with `StaticValue`. Arithmetic behaves like
the wrapped array, and `StaticValue + StaticValue` returns a `StaticValue`:
```python
import numpy as np
import jax.numpy as jnp
import unxt as u
sv = u.quantity.StaticValue(np.array([1.0, 2.0]))
q_static = u.Q(sv, "m")
q = u.Q(jnp.array([3.0, 4.0]), "m")
print(q_static + q)
```
`unxt` is built on [`quax`][quax], which enables custom array-ish objects in
JAX. For convenience we use the [`quaxed`][quaxed] library, which is just a
`quax.quaxify` wrapper around `jax` to avoid boilerplate code.
> [!NOTE]
>
> Using [`quaxed`][quaxed] is optional. You can directly use `quaxify`, and even
> apply it to the top-level function instead of individual functions.
```python
from quaxed import grad, vmap
import quaxed.numpy as jnp
print(jnp.square(x))
# Quantity['area']([ 1., 4., 9., 16.], unit='km2')
print(jnp.power(x, 3))
# Quantity['volume']([ 1., 8., 27., 64.], unit='km3')
print(vmap(grad(lambda x: x**3))(x))
# Quantity['area']([ 3., 12., 27., 48.], unit='km2')
```
See the [documentation][rtd-link] for more examples and details on JIT
compilation and automatic differentiation.
## Citation
[![JOSS][joss-badge]][joss-link] [![DOI][zenodo-badge]][zenodo-link]
If you found this library to be useful and want to support the development and
maintenance of lower-level code libraries for the scientific community, please
consider citing this work.
## Contributing and Development
[![Actions Status][actions-badge]][actions-link]
[![Documentation Status][rtd-badge]][rtd-link]
[![codecov][codecov-badge]][codecov-link]
[![SPEC 0 — Minimum Supported Dependencies][spec0-badge]][spec0-link]
[![pre-commit][pre-commit-badge]][pre-commit-link]
[![ruff][ruff-badge]][ruff-link]
We welcome contributions! Contributions are how open source projects improve and
grow.
To contribute to `unxt`, please
[fork](https://github.com/GalacticDynamics/unxt/fork) the repository, make a
development branch, develop on that branch, then
[open a pull request](https://github.com/GalacticDynamics/unxt/compare) from the
branch in your fork to main.
To report bugs, request features, or suggest other ideas, please
[open an issue](https://github.com/GalacticDynamics/unxt/issues/new/choose).
For more information, see [CONTRIBUTING.md](CONTRIBUTING.md).
[equinox]: https://docs.kidger.site/equinox/
[jax]: https://jax.readthedocs.io/en/latest/
[quax]: https://github.com/patrick-kidger/quax
[quaxed]: https://github.com/GalacticDynamics/quaxed
[actions-badge]: https://github.com/GalacticDynamics/unxt/actions/workflows/ci.yml/badge.svg?branch=main
[actions-link]: https://github.com/GalacticDynamics/unxt/actions
[codecov-badge]: https://codecov.io/gh/GalacticDynamics/unxt/graph/badge.svg
[codecov-link]: https://codecov.io/gh/GalacticDynamics/unxt
[joss-badge]: https://joss.theoj.org/papers/10.21105/joss.07771/status.svg
[joss-link]: https://doi.org/10.21105/joss.07771
[pre-commit-badge]: https://img.shields.io/badge/pre--commit-enabled-brightgreen?logo=pre-commit
[pre-commit-link]: https://pre-commit.com
[pypi-link]: https://pypi.org/project/unxt/
[pypi-platforms]: https://img.shields.io/pypi/pyversions/unxt
[pypi-version]: https://img.shields.io/pypi/v/unxt
[rtd-badge]: https://readthedocs.org/projects/unxt/badge/?version=latest
[rtd-link]: https://unxt.readthedocs.io/en/
[ruff-badge]: https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/charliermarsh/ruff/main/assets/badge/v2.json
[ruff-link]: https://docs.astral.sh/ruff/
[spec0-badge]: https://img.shields.io/badge/SPEC-0-green?labelColor=%23004811&color=%235CA038
[spec0-link]: https://scientific-python.org/specs/spec-0000/
[zenodo-badge]: https://zenodo.org/badge/734877295.svg
[zenodo-link]: https://zenodo.org/doi/10.5281/zenodo.10850455