https://github.com/tillahoffmann/testax
Jit-able runtime assertions for JAX in NumPy style.
- Host: GitHub
- URL: https://github.com/tillahoffmann/testax
- Owner: tillahoffmann
- License: apache-2.0
- Created: 2024-02-22T01:12:24.000Z
- Default Branch: main
- Last Pushed: 2024-03-03T04:52:36.000Z
- Last Synced: 2024-04-23T20:41:55.705Z
- Topics: jax, numpy, testing
- Language: Python
- Homepage: https://testax.readthedocs.io
- Size: 44.9 KB
- Stars: 4
- Watchers: 2
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.rst
- License: LICENSE
README
🧪 testax
=========

.. image:: https://img.shields.io/pypi/v/testax
   :target: https://pypi.org/project/testax
.. image:: https://github.com/tillahoffmann/testax/actions/workflows/build.yml/badge.svg
   :target: https://github.com/tillahoffmann/testax/actions/workflows/build.yml
.. image:: https://readthedocs.org/projects/testax/badge/?version=latest
   :target: https://testax.readthedocs.io/en/latest/?badge=latest

testax provides runtime assertions for JAX through the testing interface familiar to NumPy users.

>>> import jax
>>> from jax import numpy as jnp
>>> import testax
>>>
>>> def safe_log(x):
...     testax.assert_array_less(0, x)
...     return jnp.log(x)
>>>
>>> safe_log(jnp.arange(2))
Traceback (most recent call last):
...
jax._src.checkify.JaxRuntimeError:
Arrays are not less-ordered
Mismatched elements: 1 / 2 (50%)
Max absolute difference: 1
Max relative difference: 1
x: Array(0, dtype=int32, weak_type=True)
y: Array([0, 1], dtype=int32)

testax assertions are :code:`jit`-able, although errors need to be functionalized to conform to JAX's requirement that functions are pure and free of side effects (see the :code:`checkify` `guide `__ for details). In short, a :code:`checkify`-ed function returns a tuple :code:`(error, value)`: the first element is an error that *may* have occurred, and the second is the return value of the original function.

>>> jitted = jax.jit(safe_log)
>>> checkified = testax.checkify(jitted)
>>> error, y = checkified(jnp.arange(2))
>>> error.throw()
Traceback (most recent call last):
...
jax._src.checkify.JaxRuntimeError:
Arrays are not less-ordered
Mismatched elements: 1 / 2 (50%)
Max absolute difference: 1
Max relative difference: 1
x: Array(0, dtype=int32, weak_type=True)
y: Array([0, 1], dtype=int32)
>>> y
Array([-inf, 0.], dtype=float32)

Installation
------------

testax is pip-installable and can be installed by running:

.. code-block:: bash

    pip install testax

Interface
---------

testax mirrors the `testing `__ interface familiar to NumPy users, providing assertions such as :code:`assert_allclose`.
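
As a sketch of that interface inside a compiled function, the hypothetical :code:`normalize` helper below uses :code:`testax.assert_allclose` to check that its output sums to one. It assumes that :code:`testax.assert_allclose` accepts positional :code:`(actual, desired)` arguments with default tolerances, like :code:`numpy.testing.assert_allclose`, and that the error returned by :code:`testax.checkify` is a JAX :code:`checkify` error with a :code:`get()` method.

.. code-block:: python

    import jax
    from jax import numpy as jnp
    import testax


    def normalize(x):
        """Hypothetical helper: scale ``x`` so that its elements sum to one."""
        y = x / x.sum()
        # Assumed NumPy-style signature: assert_allclose(actual, desired).
        testax.assert_allclose(y.sum(), 1.0)
        return y


    checked = testax.checkify(jax.jit(normalize))

    # [1, 1, 2] / 4 = [0.25, 0.25, 0.5] sums to exactly one, so no error is recorded.
    error, y = checked(jnp.array([1.0, 1.0, 2.0]))
    error.throw()  # Does nothing because the assertion held.

    # [1, -1] sums to zero, so the normalized values are infinite, their sum is NaN,
    # and an error is recorded instead of being raised inside the jitted function.
    error, y = checked(jnp.array([1.0, -1.0]))
    print(error.get())  # The error message, or None if no check had failed.

Because the failure is recorded rather than raised inside the jitted function, the caller decides whether to call :code:`error.throw()` or to inspect the error first.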