{"id":15977702,"url":"https://github.com/suned/dependent-jax","last_synced_at":"2025-03-03T02:13:26.082Z","repository":{"id":45328844,"uuid":"440310493","full_name":"suned/dependent-jax","owner":"suned","description":null,"archived":false,"fork":false,"pushed_at":"2021-12-20T21:34:10.000Z","size":22,"stargazers_count":0,"open_issues_count":0,"forks_count":0,"subscribers_count":1,"default_branch":"main","last_synced_at":"2025-01-13T12:50:10.086Z","etag":null,"topics":[],"latest_commit_sha":null,"homepage":null,"language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"mit","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/suned.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null}},"created_at":"2021-12-20T21:24:32.000Z","updated_at":"2021-12-20T21:34:11.000Z","dependencies_parsed_at":"2022-09-16T11:51:40.733Z","dependency_job_id":null,"html_url":"https://github.com/suned/dependent-jax","commit_stats":null,"previous_names":[],"tags_count":0,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/suned%2Fdependent-jax","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/suned%2Fdependent-jax/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/suned%2Fdependent-jax/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/suned%2Fdependent-jax/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/suned","download_url":"https://codeload.github.com/suned/dependent-jax/tar.gz/refs/heads/main","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":241596276,"owners_count":19988044,"icon_url":"https://gi
thub.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":[],"created_at":"2024-10-07T23:01:41.915Z","updated_at":"2025-03-03T02:13:26.062Z","avatar_url":"https://github.com/suned.png","language":"Python","funding_links":[],"categories":[],"sub_categories":[],"readme":"# dependent-jax\n\nProof-of-concept implementation of dependent types for statically verifiable n-dimensional array operations with `jax` and `numpy`\nby way of a [stub-only package](https://www.python.org/dev/peps/pep-0561/#stub-only-packages)\nand a [mypy plugin](https://mypy.readthedocs.io/en/stable/extending_mypy.html#extending-mypy-using-plugins).\n\nNote that this is very much a work in progress, and at present only a handful of operations are supported as a basic\nproof-of-concept.\n\n## What Is This?\nIn most type systems there is a bright line between _types_ and _values_. Values are\nthe stuff you assign to variables, e.g.:\n- `42`\n- `\"the string\"`\n\nTypes, on the other hand, are _sets of values_ that you talk about with your type checker\nthrough type annotations and inference. Examples of types are:\n- `int` (to which the value `42` belongs)\n- `str` (to which the value `\"the string\"` belongs)\n\n[Dependent types](https://en.wikipedia.org/wiki/Dependent_type) blur the line between values and types by allowing you to talk about values with your type checker. 
In Python,\nthis can be done using the [`Literal`](https://docs.python.org/3/library/typing.html#typing.Literal) type:\n```python\nfrom typing import Literal\n\n\nFortyTwo = Literal[42]  # Type alias for a type (i.e. a set) that only contains the value 42\nx: FortyTwo = 42\ny: FortyTwo = 0  # Type error, because 0 does not belong to the type Literal[42]\n```\n\n`dependent-jax` is a proof-of-concept of how to use `Literal` to annotate `jax.numpy.DeviceArray` and `numpy.ndarray` types\nwith shape information, thereby providing _static verification of tensor operations_. In other words, `dependent-jax` helps `mypy`\ncatch many errors related to tensor shape mismatches that would otherwise only surface as\nruntime errors.\n\n`dependent-jax` currently demonstrates the feasibility of the following kinds of annotations/inferences:\n\n- Annotating array types with shape information\n- Inferring shapes of arrays returned from functions that accept a `shape` parameter\n- Checking array shape compatibility and inferring shapes of arrays returned from binary broadcasting operations\n- Inferring shapes of arrays returned from unary operations\n- Checking array shape compatibility and inferring shapes of arrays returned from matrix multiplication\n- Inferring shapes of arrays returned from unparameterized shape manipulation (e.g. `array.flatten()`)\n- Inferring shapes of arrays returned from parameterized shape manipulation (e.g. `array.reshape((...))`)\n- Checking argument compatibility and inferring shapes of arrays returned from index operations\n\nIt should be possible to extend each of these approaches to many similar functions/methods\nin the `jax`/`numpy` API with little effort.\n\n## Install\nFrom GitHub, e.g. using `pip`:\n```commandline\npip install git+https://github.com/suned/dependent-jax\n```\nAdd the following to your [mypy config file](https://mypy.readthedocs.io/en/stable/config_file.html) to enable the mypy plugin (this package doesn't make any sense without 
it):\n```ini\n[mypy]\nplugins = dependent_jax\n```\n\n## Usage\nWhen instantiating arrays from I/O or from Python values (e.g. `list` instances), there\nis no way to infer the array shape, so it must be supplied via annotation. `jax.numpy.DeviceArray` and `numpy.ndarray` accept at\nleast two type parameters. All type parameters except the last must be `Literal` integer types, one per dimension. The last type parameter is always the scalar type of the array:\n```python\nfrom typing import Literal\n\nimport jax.numpy as jnp\nimport numpy as np\n\n\na: jnp.DeviceArray[Literal[3], Literal[2], jnp.float32] = jnp.array([[1, 2], [3, 4], [5, 6]])\nb: np.ndarray[Literal[3], Literal[2], np.float64] = np.array([[1, 2], [3, 4], [5, 6]])\n\nreveal_type(a)  # note: Revealed type is \"jax.numpy.DeviceArray[Literal[3], Literal[2], jax.numpy.float32]\"\nreveal_type(b)  # note: Revealed type is \"numpy.ndarray[Literal[3], Literal[2], numpy.float64]\"\n```\n\n`typing.Any` in place of the shape parameter(s) always indicates an array of unknown shape:\n\n```python\nimport jax.numpy as jnp\n\n\nreveal_type(jnp.array([]))  # note: Revealed type is \"jax.numpy.DeviceArray[Any, jax.numpy.float32]\"\n```\n\nWhen instantiating arrays with functions that take a shape parameter,\nthe resulting shape can be inferred provided that the shape arguments are\nliteral types:\n```python\nimport jax.numpy as jnp\n\n\na = jnp.zeros((2, 2))\nreveal_type(a)  # note: Revealed type is \"jax.numpy.DeviceArray[Literal[2], Literal[2], jax.numpy.float32]\"\n```\n\nWith `mypy`, values are interpreted as literal types when:\n- The value is supplied directly as an argument (e.g. `jnp.zeros((2, 2))`)\n- A variable is annotated with `Literal` (e.g. `two: Literal[2] = 2`)\n- A variable is annotated with `Final` (e.g. `two: Final = 2`)\n\nThis means that the return type of `jnp.zeros` can be inferred in all of the following calls:\n\n```python\nfrom typing import Literal, Final\n\nimport jax.numpy as 
jnp\n\n\na: Literal[2] = 2\nb: Final = 2\n\njnp.zeros((2, 2))\njnp.zeros((a, a))\njnp.zeros((b, b))\n```\n\nbut not in the following, where `a` is inferred as `int` rather than as a literal type:\n\n```python\nimport jax.numpy as jnp\n\n\na = 2\njnp.zeros((a, a))\n```\n\nThe shapes of arrays resulting from operations on arrays of known shape can be inferred, and errors\nresulting from incompatible dimensions are reported by `mypy`:\n\n```python\nfrom typing import Literal\n\nimport jax.numpy as jnp\n\n\na: jnp.DeviceArray[Literal[3], Literal[2], jnp.float32]\nb: jnp.DeviceArray[Literal[2], Literal[1], jnp.float32]\n\nreveal_type(a @ b)  # note: Revealed type is \"jax.numpy.DeviceArray[Literal[3], Literal[1], jax.numpy.float32]\"\n```\n\nThe shapes of arrays resulting from index operations can currently only be inferred when\nthe index arguments are:\n\n- Literal integers\n- In-line slice expressions\n- `Tuple` types with literal integer element types, in the case of [advanced indexing](https://numpy.org/doc/stable/reference/arrays.indexing.html#advanced-indexing)\n\nFor example, in:\n```python\nfrom typing import Final\n\nimport jax.numpy as jnp\n\n\nzero: Final = 0\na = jnp.zeros((3, 2))\n\nreveal_type(a[0])       # note: Revealed type is \"jax.numpy.DeviceArray[Literal[2], jax.numpy.float32]\"\nreveal_type(a[zero:2])  # note: Revealed type is \"jax.numpy.DeviceArray[Literal[2], Literal[2], jax.numpy.float32]\"\n```\n\nBut not in:\n```python\nimport jax.numpy as jnp\n\n\na = jnp.zeros((3, 2))\ns = slice(0, 1)\n# Inference of index operations with slices only works with in-line slice expressions\nreveal_type(a[s])  # note: Revealed type is \"jax.numpy.DeviceArray[Any, jax.numpy.float32]\"\n```\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fsuned%2Fdependent-jax","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fsuned%2Fdependent-jax","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fsuned%2Fdependent-jax/lists"}