https://github.com/epronovost/eincheck
Tensor shape checks inspired by Einstein notation
deep-learning jax numpy pytorch tensor tensorflow
- Host: GitHub
- URL: https://github.com/epronovost/eincheck
- Owner: EPronovost
- License: mit
- Created: 2022-10-29T02:38:07.000Z
- Default Branch: main
- Last Pushed: 2024-10-25T16:01:59.000Z
- Last Synced: 2024-10-30T02:58:26.114Z
- Topics: deep-learning, jax, numpy, pytorch, tensor, tensorflow
- Language: Python
- Size: 149 KB
- Stars: 3
- Watchers: 2
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
README
# eincheck
[PR checks](https://github.com/epronovost/eincheck/actions/workflows/pr.yaml) | [Documentation status](https://eincheck.readthedocs.io/en/main/?badge=main) | [PyPI version](https://badge.fury.io/py/eincheck)
Tensor shape checks inspired by Einstein notation
## Overview
This library has three main functions:
* `check_shapes` takes `(tensor, shape)` tuples and checks that every tensor matches its shape spec; a named dimension such as `i` must have the same size everywhere it appears.
```python
from eincheck import check_shapes
# x and y must both have shape (i, 3), with the same i.
check_shapes((x, "i 3"), (y, "i 3"))
```
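To make the binding rule concrete, here is a minimal pure-Python sketch of the idea behind `check_shapes`. It is not eincheck's implementation: `toy_check_shapes` is a hypothetical helper that works on plain shape tuples (for example `x.shape` from NumPy or PyTorch), understands only integer literals and dimension names, and raises `ValueError` on a mismatch.

```python
def toy_check_shapes(*pairs):
    """Toy sketch of the check_shapes idea (hypothetical, not eincheck's code).

    Each pair is (shape_tuple, spec). A spec is whitespace-separated tokens:
    an integer literal must match that dimension exactly, while a name binds
    to the first size it meets and must equal that size everywhere else.
    Returns the name -> size bindings; raises ValueError on any mismatch.
    """
    bindings = {}
    for shape, spec in pairs:
        tokens = spec.split()
        if len(shape) != len(tokens):
            raise ValueError(f"rank mismatch: {tuple(shape)} vs {spec!r}")
        for size, token in zip(shape, tokens):
            if token.isdigit():
                if size != int(token):
                    raise ValueError(f"expected {token}, got {size} in {spec!r}")
            elif bindings.setdefault(token, size) != size:
                # setdefault returned a size bound earlier that disagrees.
                raise ValueError(
                    f"{token!r} is {bindings[token]} but got {size} in {spec!r}"
                )
    return bindings


# Both shapes share i = 5 and end in 3, so this passes.
print(toy_check_shapes(((5, 3), "i 3"), ((5, 3), "i 3")))  # {'i': 5}
```

Because a name binds on its first appearance and every later appearance must agree, the order of the pairs only changes which spec the error message points at, not whether the check passes.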
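* `check_func` is a function decorator that checks the input and output shapes of a function. Here `*i` matches any number of leading dimensions, and `(x + y)` requires the output's last dimension to equal the sum of the inputs' last dimensions.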
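```python
import numpy as np
from eincheck import check_func

@check_func("*i x, *i y -> *i (x + y)")
def concat(a, b):
    # Concatenate along the last axis, so the output's last dim is x + y.
    return np.concatenate([a, b], -1)
```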
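The spec syntax can be illustrated with a toy decorator. `toy_check_func` below is a hypothetical sketch, not eincheck's parser, and it makes several simplifying assumptions: it accepts only positional arguments that have a `.shape` attribute (NumPy arrays and PyTorch tensors both qualify), allows at most one variadic `*name` per spec and binds it to a tuple of dimensions, evaluates parenthesized expressions such as `(x + y)` with a restricted `eval` over the sizes bound so far (acceptable only because specs are trusted literals in source code), and splits the signature naively on `,` and `->`.

```python
import functools
import re

# Tokens: "*name" (variadic), "(...)" (arithmetic expression), or a bare word.
_TOKEN = re.compile(r"\*\w+|\([^)]*\)|\w+")


def _bind(env, name, value):
    # A name binds on first use; later uses must agree.
    if env.setdefault(name, value) != value:
        raise ValueError(f"{name!r} is {env[name]} but got {value}")


def _match(shape, spec, env):
    """Match one shape tuple against one spec, updating env in place."""
    tokens = _TOKEN.findall(spec)
    stars = [k for k, tok in enumerate(tokens) if tok.startswith("*")]
    if len(stars) > 1:
        raise ValueError(f"toy parser allows one variadic token: {spec!r}")
    if stars:
        k = stars[0]
        n_tail = len(tokens) - k - 1
        if len(shape) < len(tokens) - 1:
            raise ValueError(f"rank too small: {shape} vs {spec!r}")
        # The variadic name binds to the whole tuple of middle dimensions.
        _bind(env, tokens[k][1:], tuple(shape[k:len(shape) - n_tail]))
        pairs = list(zip(tokens[:k], shape[:k]))
        pairs += list(zip(tokens[k + 1:], shape[len(shape) - n_tail:]))
    else:
        if len(shape) != len(tokens):
            raise ValueError(f"rank mismatch: {shape} vs {spec!r}")
        pairs = list(zip(tokens, shape))
    for tok, size in pairs:
        if tok.isdigit():
            expected = int(tok)
        elif tok.startswith("("):
            # Evaluate e.g. "(x + y)" using sizes bound so far (toy only: eval).
            expected = eval(tok, {"__builtins__": {}}, dict(env))
        else:
            _bind(env, tok, size)
            continue
        if size != expected:
            raise ValueError(f"{tok} should be {expected}, got {size} in {shape}")


def toy_check_func(signature):
    """Hypothetical decorator checking positional array args and the result."""
    inputs, output = (part.strip() for part in signature.split("->"))
    in_specs = [s.strip() for s in inputs.split(",")]

    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args):
            if len(args) != len(in_specs):
                raise TypeError(f"expected {len(in_specs)} arguments")
            env = {}  # bindings shared across all inputs and the output
            for arg, spec in zip(args, in_specs):
                _match(tuple(arg.shape), spec, env)
            result = fn(*args)
            _match(tuple(result.shape), output, env)
            return result

        return wrapper

    return decorator
```

Applied to the README's `concat` in place of `@check_func`, inputs of shape `(2, 4, 3)` and `(2, 4, 5)` bind `i = (2, 4)`, `x = 3` and `y = 5`, so the output must have shape `(2, 4, 8)`. Sharing one bindings dictionary between the inputs and the output is what lets the output spec refer to sizes that only the inputs define.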
* `check_data` is a class decorator that checks the shapes of a data class's fields.
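```python
from typing import NamedTuple

import torch
from eincheck import check_data

@check_data(start="i 2", end="i 2")
class LineSegment2D(NamedTuple):
    start: torch.Tensor  # shape (i, 2)
    end: torch.Tensor  # shape (i, 2), same i as start
```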
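A class decorator like `check_data` has to hook object construction. The sketch below shows one way to do that for `NamedTuple` classes only; `toy_check_data` and `ToySegment2D` are hypothetical, the checker understands just names and integer literals, and this is not how eincheck is implemented. The decorator returns a subclass whose `__new__` validates every listed field, with one set of bindings shared across fields so `i` must agree between `start` and `end`.

```python
from typing import NamedTuple


def _check(shape, spec, env):
    """Match one shape tuple against a spec of names and integer literals."""
    tokens = spec.split()
    if len(shape) != len(tokens):
        raise ValueError(f"rank mismatch: {tuple(shape)} vs {spec!r}")
    for size, tok in zip(shape, tokens):
        # Literals must match exactly; names bind on first use, then must agree.
        expected = int(tok) if tok.isdigit() else env.setdefault(tok, size)
        if size != expected:
            raise ValueError(f"{tok!r} should be {expected}, got {size}")


def toy_check_data(**specs):
    """Hypothetical class decorator for NamedTuple classes (not eincheck's)."""

    def decorator(cls):
        class Checked(cls):
            __slots__ = ()

            def __new__(klass, *args, **kwargs):
                obj = super().__new__(klass, *args, **kwargs)
                env = {}  # one set of bindings shared by all fields
                for field, spec in specs.items():
                    _check(tuple(getattr(obj, field).shape), spec, env)
                return obj

        # Note: _make() and _replace() build tuples directly and skip this check.
        Checked.__name__ = cls.__name__
        Checked.__qualname__ = cls.__qualname__
        return Checked

    return decorator


@toy_check_data(start="i 2", end="i 2")
class ToySegment2D(NamedTuple):
    start: object  # anything with a .shape, e.g. a torch.Tensor
    end: object
```

With this sketch, constructing `ToySegment2D` from tensors of shapes `(4, 2)` and `(3, 2)` raises `ValueError`, while `(4, 2)` and `(4, 2)` succeed. Returning a subclass rather than patching `__new__` in place leaves the original class untouched; the trade-off is that `type(seg)` is the subclass, so the sketch copies `__name__` and `__qualname__`, which keeps the NamedTuple `repr` reading `ToySegment2D(...)`.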
For more info, [read the docs!](https://eincheck.readthedocs.io/en/main)