https://github.com/evhub/better_einsum
np.einsum but better
- Host: GitHub
- URL: https://github.com/evhub/better_einsum
- Owner: evhub
- License: apache-2.0
- Created: 2022-05-25T20:20:33.000Z (over 3 years ago)
- Default Branch: main
- Last Pushed: 2022-10-21T06:45:57.000Z (about 3 years ago)
- Last Synced: 2025-04-13T17:07:19.055Z (9 months ago)
- Language: Python
- Size: 28.3 KB
- Stars: 7
- Watchers: 2
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
# `better_einsum`
_[`np.einsum`](https://numpy.org/doc/stable/reference/generated/numpy.einsum.html) but better:_
- better syntax (`"C[i,k] = A[i,j] B[j,k]"` instead of `"ij, jk -> ik"`),
- names and indices can be arbitrary variable names, not just single letters,
- support for keyword arguments (`einsum("C = A[i] B[i]", A=..., B=...)`),
- warnings on common bugs, and
- an `einsum.exec` method for executing the einsum assignment in the calling scope.
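To make the syntax mapping in the list above concrete, here is a minimal sketch of how a bracket-style expression can be translated into NumPy's conventional subscript string. The helper `to_einsum_subscripts` is hypothetical and is not better_einsum's actual parser. It assumes operands are separated by whitespace or an optional `*`, maps each distinct index name (of any length) to a single letter, and passes `...` through unchanged. It does not implement keyword arguments, the warnings, or `einsum.exec`.

```python
import re
import string

import numpy as np

# Hypothetical helper (not part of better_einsum): matches a term like "A[i,j]" or "C".
_TERM = re.compile(r"(\w+)(?:\[([^\]]*)\])?")


def to_einsum_subscripts(expr: str) -> str:
    """Translate "C[i,k] = A[i,j] * B[j,k]" into NumPy's "ab,bc->ac" form.

    Each distinct index name, however long, is mapped to one letter in order of
    first appearance; "..." is copied through unchanged.
    """
    lhs, rhs = (side.strip() for side in expr.split("=", 1))
    letters = iter(string.ascii_letters)
    mapping = {}  # index name -> single einsum letter

    def subscripts(term: str) -> str:
        match = _TERM.fullmatch(term.strip())
        if match is None:
            raise ValueError(f"cannot parse term {term!r}")
        out = []
        for name in filter(None, (i.strip() for i in (match.group(2) or "").split(","))):
            if name == "...":
                out.append("...")
            else:
                if name not in mapping:
                    mapping[name] = next(letters)
                out.append(mapping[name])
        return "".join(out)

    # Operand terms on the right-hand side; a bare "*" between them is ignored.
    terms = re.findall(r"\w+(?:\[[^\]]*\])?", rhs)
    inputs = ",".join(subscripts(t) for t in terms)
    return f"{inputs}->{subscripts(lhs)}"


A = np.array([[1, 2], [3, 4]])
B = np.array([[5, 6], [7, 8]])
spec = to_einsum_subscripts("C[row,col] = A[row,inner] * B[inner,col]")
print(spec)  # ab,bc->ac
print(np.einsum(spec, A, B))  # same result as A.dot(B)
```

Mapping names to letters on first appearance keeps the translation deterministic, and because the output indices are translated last, an output index that never appears on the right-hand side produces a subscript that `np.einsum` itself rejects.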
Install with `pip install better_einsum`, then:
```pycon
>>> import numpy as np
>>> from better_einsum import einsum
>>> A = np.array([[1, 2], [3, 4]])
>>> B = np.array([[5, 6], [7, 8]])
>>> einsum("C[i,k] = A[i,j] * B[j,k]", A=A, B=B) # equivalent to A.dot(B)
array([[19, 22],
       [43, 50]])
>>> einsum("C = A[i,j] * B[i,j]", A=A, B=B) # equivalent to np.sum(A * B)
70
>>> einsum("C[...] = A[i,...] * B[i,...]", A=A, B=B) # equivalent to np.sum(A * B, axis=0)
array([26, 44])
>>> einsum("C[i,k] = A[i,j] B[j,k]", A, B) # * is optional; positional args are also supported
array([[19, 22],
       [43, 50]])
>>> einsum("C[i,k] = A[i,j] * B[j,k]", A, A) # better_einsum will catch common mistakes for you
better_einsum.py: UserWarning: better_einsum: variable 'B' in calling scope points to a different object than was passed in; this usually denotes an error
array([[ 7, 10],
       [15, 22]])
>>> einsum("_[i,k] = _[i,j] * _[j,k]", A, B) # use placeholders if you don't want to name your variables
array([[19, 22],
       [43, 50]])
>>> einsum.exec("C[i,k] = A[i,j] * B[j,k]") # directly assigns to C and looks up A and B
array([[19, 22],
       [43, 50]])
>>> C
array([[19, 22],
       [43, 50]])
>>> import jax.numpy as jnp
>>> from functools import partial
>>> jnp_einsum = partial(einsum, base_einsum_func=jnp.einsum) # better_einsum for JAX
```
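For comparison, the first three README examples can be written directly against `np.einsum` using its standard subscript strings. This is plain NumPy, not better_einsum; it only shows which conventional einsum call each bracket-style expression corresponds to, and the expected values match the README outputs.

```python
import numpy as np

A = np.array([[1, 2], [3, 4]])
B = np.array([[5, 6], [7, 8]])

# "C[i,k] = A[i,j] * B[j,k]": j is summed away, giving the matrix product A.dot(B).
product = np.einsum("ij,jk->ik", A, B)

# "C = A[i,j] * B[i,j]": no output indices, so both i and j are summed (np.sum(A * B)).
total = np.einsum("ij,ij->", A, B)

# "C[...] = A[i,...] * B[i,...]": only the leading axis is summed (np.sum(A * B, axis=0)).
column_sums = np.einsum("i...,i...->...", A, B)

print(product.tolist())      # [[19, 22], [43, 50]]
print(int(total))            # 70
print(column_sums.tolist())  # [26, 44]
```

The bracket form names the output explicitly on the left of `=`, whereas NumPy's form puts it after `->`; in both, any index missing from the output is summed over.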