https://github.com/evhub/better_einsum

np.einsum but better

# `better_einsum`

_[`np.einsum`](https://numpy.org/doc/stable/reference/generated/numpy.einsum.html) but better:_

- better syntax (`"C[i,k] = A[i,j] B[j,k]"` instead of `"ij, jk -> ik"`),
- names and indices can be arbitrary variable names, not just single letters,
- support for keyword arguments (`einsum("C = A[i] B[i]", A=..., B=...)`),
- warnings on common bugs, and
- an `einsum.exec` method for executing the einsum assignment in the calling scope.
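The first two bullets can be sketched as follows: a named-index statement is lowered to a classic `np.einsum` subscript string. This is an illustrative assumption, not better_einsum's actual parser. The helper `to_numpy_subscripts` is hypothetical. It supports only plain `Out[...] = X[...] Y[...]` statements (optional `*`, no `...`) and can map at most 26 distinct index names to letters.

```python
import re
import string

import numpy as np

# Hypothetical helper, NOT better_einsum's parser. It handles only statements of
# the form "Out[i,k] = X[i,j] Y[j,k]" (optional '*' between factors, no '...').
_TERM = re.compile(r"(\w+)(?:\[([^\]]*)\])?")


def to_numpy_subscripts(statement: str) -> tuple[str, list[str]]:
    """Return (np.einsum subscript string, operand names in order)."""
    lhs, rhs = statement.split("=")
    letters: dict[str, str] = {}  # index name -> single letter, assigned on first use

    def encode(indices: str) -> str:
        encoded = []
        for name in (part.strip() for part in indices.split(",")):
            if not name:
                continue
            if name not in letters:
                # Raises IndexError past 26 distinct index names.
                letters[name] = string.ascii_lowercase[len(letters)]
            encoded.append(letters[name])
        return "".join(encoded)

    names, inputs = [], []
    for name, indices in _TERM.findall(rhs.replace("*", " ")):
        names.append(name)
        inputs.append(encode(indices))
    output = _TERM.fullmatch(lhs.strip())  # e.g. "C[i,k]" or a bare "C" for a scalar
    return ",".join(inputs) + "->" + encode(output.group(2) or ""), names


A = np.array([[1, 2], [3, 4]])
B = np.array([[5, 6], [7, 8]])
subscripts, names = to_numpy_subscripts("C[row,col] = A[row,mid] * B[mid,col]")
print(subscripts)  # ab,bc->ac
operands = {"A": A, "B": B}  # stands in for the keyword arguments A=..., B=...
C = np.einsum(subscripts, *(operands[n] for n in names))
print(C)
```

Assigning each multi-character index name a fresh letter on first use lets names like `row` and `mid` sit on top of NumPy's single-letter grammar.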

Install with `pip install better_einsum`, then:

```pycon
>>> import numpy as np
>>> from better_einsum import einsum

>>> A = np.array([[1, 2], [3, 4]])
>>> B = np.array([[5, 6], [7, 8]])

>>> einsum("C[i,k] = A[i,j] * B[j,k]", A=A, B=B) # equivalent to A.dot(B)
array([[19, 22],
       [43, 50]])

>>> einsum("C = A[i,j] * B[i,j]", A=A, B=B) # equivalent to np.sum(A * B)
70

>>> einsum("C[...] = A[i,...] * B[i,...]", A=A, B=B) # equivalent to np.sum(A * B, axis=0)
array([26, 44])

>>> einsum("C[i,k] = A[i,j] B[j,k]", A, B) # * is optional; positional args are also supported
array([[19, 22],
       [43, 50]])

>>> einsum("C[i,k] = A[i,j] * B[j,k]", A, A) # better_einsum will catch common mistakes for you
better_einsum.py: UserWarning: better_einsum: variable 'B' in calling scope points to a different object than was passed in; this usually denotes an error
array([[ 7, 10],
       [15, 22]])

>>> einsum("_[i,k] = _[i,j] * _[j,k]", A, B) # use placeholders if you don't want to name your variables
array([[19, 22],
       [43, 50]])

>>> einsum.exec("C[i,k] = A[i,j] * B[j,k]") # directly assigns to C and looks up A and B
array([[19, 22],
       [43, 50]])
>>> C
array([[19, 22],
       [43, 50]])

>>> import jax.numpy as jnp
>>> from functools import partial
>>> jnp_einsum = partial(einsum, base_einsum_func=jnp.einsum) # better_einsum for JAX
```
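For comparison, here are the first three examples written with plain `np.einsum` subscripts. They can also serve as a quick check of the outputs shown above. This uses only NumPy, not better_einsum, and the variable names `matmul`, `total` and `col_sums` are illustrative.

```python
import numpy as np

A = np.array([[1, 2], [3, 4]])
B = np.array([[5, 6], [7, 8]])

# "C[i,k] = A[i,j] * B[j,k]": matrix product
matmul = np.einsum("ij,jk->ik", A, B)  # [[19, 22], [43, 50]]

# "C = A[i,j] * B[i,j]": every index is summed, leaving a scalar
total = np.einsum("ij,ij->", A, B)  # 70

# "C[...] = A[i,...] * B[i,...]": sum over the leading axis only
col_sums = np.einsum("i...,i...->...", A, B)  # [26, 44]

print(matmul, total, col_sums, sep="\n")
```

An index that is omitted from the output (after `->`) is summed over, which is why `"ij,ij->"` collapses to a single number.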