Einstein Summation for Arrays in R
https://github.com/const-ae/einsum

---
output: github_document
---

```{r, include = FALSE}
knitr::opts_chunk$set(
collapse = TRUE,
comment = "#>",
fig.path = "man/figures/README-",
out.width = "100%"
)
set.seed(1)
```

# einsum

[![R build status](https://github.com/const-ae/einsum/workflows/R-CMD-check/badge.svg)](https://github.com/const-ae/einsum/actions)
[![Codecov test coverage](https://codecov.io/gh/const-ae/einsum/branch/master/graph/badge.svg)](https://app.codecov.io/gh/const-ae/einsum?branch=master)

[Einstein summation](https://en.wikipedia.org/wiki/Einstein_notation) is a concise mathematical notation that
implicitly sums over repeated indices of n-dimensional arrays. Many ordinary
matrix operations (e.g., transpose, matrix multiplication, scalar product, `diag()`, and trace)
can be written using Einstein notation. The notation is particularly convenient for
expressing operations on arrays with more than two dimensions, because the
corresponding operators ('tensor products') often have no standardized name.

## Installation

You can install the package from [CRAN](https://CRAN.R-project.org/package=einsum) with:
``` r
install.packages("einsum")
```
Or, if you want to use the development version from [GitHub](https://github.com/const-ae/einsum):
``` r
# install.packages("devtools")
devtools::install_github("const-ae/einsum")
```

## Example

Load the package:

```{r}
library(einsum)
```

Let's make two matrices:

```{r}
mat1 <- matrix(rnorm(n = 8 * 4), nrow = 8, ncol = 4)
mat2 <- matrix(rnorm(n = 4 * 4), nrow = 4, ncol = 4)
```

We can use `einsum()` to calculate the matrix product:
```{r}
einsum("ij, jk -> ik", mat1, mat2)
```
which produces the same result as standard matrix multiplication:
```{r}
mat1 %*% mat2
```
The matrix multiplication example is straightforward, and there is little benefit to using Einstein notation over the more familiar `%*%` expression. Furthermore, `einsum()` is a lot slower than `%*%`.
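
The other ordinary operations mentioned in the introduction follow the same pattern, and they too are usually simpler to write with the base R functions. The sketch below assumes the package is installed; the matrix `m` and the result names are made up for this illustration:

```r
library(einsum)

# A small square matrix (stored as doubles) so that diag() and the trace exist
m <- matrix(as.numeric(1:9), nrow = 3, ncol = 3)

# Transpose: the output lists the two indices in swapped order
t_m <- einsum("ij->ji", m)     # compare with t(m)

# Diagonal: repeating an index on one array selects matching positions
diag_m <- einsum("ii->i", m)   # compare with diag(m)

# Trace: the repeated index is missing from the output, so it is summed over
tr_m <- einsum("ii->", m)      # compare with sum(diag(m))
```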

However, `einsum()` truly shines when working with arrays that have more than two dimensions, where it can be difficult to figure out the correct kind of tensor product:
```{r}
# Make three n-dimensional arrays
arr1 <- array(rnorm(3 * 9 * 2), dim = c(3, 9, 2))
arr2 <- array(rnorm(2 * 5), dim = c(2, 5))
arr3 <- array(rnorm(9 * 3), dim = c(9, 3))
# Sum across axes a, b, and c
einsum("abc, cd, ba -> d", arr1, arr2, arr3)
```
The equivalent expression using tensor products (which are not intuitive) would look like this:
```{r}
tensor::tensor(tensor::tensor(arr1, arr2, alongA = 3, alongB = 1), arr3, alongA = c(2,1), alongB = c(1, 2))
```
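
To make the index notation concrete, the contraction `"abc, cd, ba -> d"` can also be spelled out as explicit loops in base R. This is only an illustrative sketch of what the expression means, not a description of how the package computes it. The arrays are recreated with the same dimensions so the block runs on its own, and the loop variables `ia`, `ib`, `ic`, and `id` are names chosen for this example:

```r
set.seed(1)
arr1 <- array(rnorm(3 * 9 * 2), dim = c(3, 9, 2))  # indices a, b, c
arr2 <- array(rnorm(2 * 5), dim = c(2, 5))         # indices c, d
arr3 <- array(rnorm(9 * 3), dim = c(9, 3))         # indices b, a

# out[d] = sum over a, b, c of arr1[a, b, c] * arr2[c, d] * arr3[b, a]
out <- numeric(5)
for (id in 1:5) {
  for (ia in 1:3) {
    for (ib in 1:9) {
      for (ic in 1:2) {
        out[id] <- out[id] + arr1[ia, ib, ic] * arr2[ic, id] * arr3[ib, ia]
      }
    }
  }
}
out
```

Every index that appears on the left of `->` but not on the right (`a`, `b`, and `c`) becomes a summation loop, and the remaining index `d` labels the output.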

If you need to do the same computation repeatedly, you can use `einsum_generator()`, which generates and compiles an efficient C++ function for that calculation (to see the function code, set `compile_function=FALSE`). It can take a few seconds to compile the function, but the returned function can be one or two orders of magnitude faster than `einsum()`.

```{r}
# einsum_generator returns a function
array_prod <- einsum_generator("abc, cd, ba -> d")
array_prod(arr1, arr2, arr3)
```

```{r}
bench::mark(
  tensor = tensor::tensor(tensor::tensor(arr1, arr2, alongA = 3, alongB = 1),
                          arr3, alongA = c(2, 1), alongB = c(1, 2)),
  einsum = einsum("abc, cd, ba -> d", arr1, arr2, arr3),
  einsum_generator = array_prod(arr1, arr2, arr3)
)
```

Lastly, `einsum_generator()` can also return the generated C++ code itself. This is useful if you need an efficient implementation of a specific calculation and want to, for example, paste it (with proper credit) into your own R package:
```{r}
# The C++ code underlying the tensor product
cat(einsum_generator("abc, cd, ba -> d", compile_function = FALSE))
```

## Credit

This package is inspired by the [einsum](https://numpy.org/doc/stable/reference/generated/numpy.einsum.html) function in NumPy.