https://github.com/mcabbott/tensorgrad.jl
- Host: GitHub
- URL: https://github.com/mcabbott/tensorgrad.jl
- Owner: mcabbott
- Created: 2019-08-10T11:01:26.000Z (almost 7 years ago)
- Default Branch: master
- Last Pushed: 2021-03-20T00:05:45.000Z (over 5 years ago)
- Last Synced: 2024-10-13T19:28:30.620Z (over 1 year ago)
- Language: Julia
- Size: 13.7 KB
- Stars: 4
- Watchers: 2
- Forks: 1
- Open Issues: 1
# TensorGrad.jl
This package adds gradient definitions for [Zygote.jl](https://github.com/FluxML/Zygote.jl)
to most calculations using [TensorOperations.jl](https://github.com/Jutho/TensorOperations.jl),
and some using [Einsum.jl](https://github.com/ahwillia/Einsum.jl).
It exports a macro `@grad` which rewrites an expression like
```julia
@grad @tensor A[i,k] := B[i,j] * C[j,k] * D[l,l]
```
into something equivalent to this:
```julia
using TensorOperations, LinearAlgebra  # @tensor, Diagonal
using Zygote: @adjoint

fun(b,c,d) = @tensor a[i,k] := b[i,j] * c[j,k] * d[l,l]  # define a function
@adjoint function fun(b,c,d)
    fwd = @tensor a[i,k] := b[i,j] * c[j,k] * d[l,l]  # forward pass
    function back(Δa)
        @tensor Δb[i,j] := Δa[i,k] * c[j,k] * d[l,l]  # reverse pass
        @tensor Δc[j,k] := b[i,j] * Δa[i,k] * d[l,l]
        δ = Diagonal(ones(size(d,1)))  # derivative of the trace d[l,l] with respect to d[l,l′]
        @tensor Δd[l,l′] := b[i,j] * c[j,k] * Δa[i,k] * δ[l,l′]
        return (Δb, Δc, Δd)
    end
    return (fwd, back)
end
A = fun(B,C,D)  # apply this to B, C, D
```
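As a package-free sanity check of the reverse-pass contractions above, here is a plain-Julia sketch that uses only the stdlib `LinearAlgebra`. It assumes `b`, `c` and `d` are matrices, so that `d[l,l]` is just `tr(d)`, and writes the same three sensitivities as matrix products. The helpers `fun_pullback` and `fd_grad` are hypothetical names, not part of TensorGrad. Comparing against central finite differences is only one way to confirm the formulas.

```julia
using LinearAlgebra

# Forward pass: a[i,k] = Σ_{j,l} b[i,j] * c[j,k] * d[l,l] = (b*c) * tr(d)
fun(b, c, d) = (b * c) * tr(d)

# Hypothetical hand-written pullback, one line per reverse-pass contraction
function fun_pullback(b, c, d, Δa)
    Δb = Δa * c' * tr(d)     # Δb[i,j] = Σ_k Δa[i,k] * c[j,k] * d[l,l]
    Δc = b' * Δa * tr(d)     # Δc[j,k] = Σ_i b[i,j] * Δa[i,k] * d[l,l]
    s = sum((b * c) .* Δa)   # Σ_{i,j,k} b[i,j] * c[j,k] * Δa[i,k]
    Δd = s * Matrix{eltype(d)}(I, size(d, 1), size(d, 2))  # times δ[l,l′]
    return (Δb, Δc, Δd)
end

# Central finite-difference gradient of a scalar function f at array x
function fd_grad(f, x; h = 1e-6)
    g = similar(x)
    for n in eachindex(x)
        xp = copy(x); xp[n] += h
        xm = copy(x); xm[n] -= h
        g[n] = (f(xp) - f(xm)) / (2h)
    end
    return g
end

b, c, d, Δa = randn(2, 3), randn(3, 4), randn(5, 5), randn(2, 4)
Δb, Δc, Δd = fun_pullback(b, c, d, Δa)
loss(b, c, d) = sum(fun(b, c, d) .* Δa)  # its gradient is the pullback of Δa
@assert isapprox(Δb, fd_grad(x -> loss(x, c, d), b); rtol = 1e-5, atol = 1e-6)
@assert isapprox(Δd, fd_grad(x -> loss(b, c, x), d); rtol = 1e-5, atol = 1e-6)
```

Because `loss` is linear in each argument separately, the central differences are exact up to rounding, so the tolerances can be tight.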
You may also write `@grad B C @tensor A[i,k] := B[i,j] * C[j,k] * D[l,l]` to specify that
only the sensitivities for `B` and `C` are needed; this removes the calculation
of `Δd` above.
To see what is being defined, call `TensorGrad.verbose(true)` before the macro
(rather than using `@macroexpand1`).
If [Tracker.jl](https://github.com/FluxML/Tracker.jl) is loaded, then it will now
define the same gradients for `B::TrackedArray` etc.
Note that this is a fairly crude experiment, probably not something to rely on.
### Limitations:
1. The expression must be one term, and scalar factors are not handled yet.
2. It makes no attempt to cache intermediate contractions for re-use,
and thus if there are many tensors it will do the same work several times
(like `b[i,j] * c[j,k]` above, done twice).
3. Requires you to add `@grad` everywhere, so won't work in other people's code.
I can solve 1, but 2 seems hard to solve with this design.
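To make limitation 2 concrete: in the generated code above, the product `b[i,j] * c[j,k]` is formed in the forward pass and then formed again for `Δd`. A hand-written pullback could keep that intermediate instead. The plain-Julia sketch below shows this re-use, using only `LinearAlgebra` and the hypothetical name `fun_cached`. It is not what `@grad` produces. It only illustrates the kind of caching the macro does not attempt.

```julia
using LinearAlgebra

# Hypothetical cached pullback for a = (b*c) * tr(d): the forward pass keeps the
# intermediate b*c, so the sensitivity for d reuses it instead of recomputing it.
function fun_cached(b, c, d)
    bc = b * c                  # the contraction b[i,j] * c[j,k], done once
    t = tr(d)                   # d[l,l]
    a = bc * t
    function back(Δa)
        Δb = Δa * c' * t
        Δc = b' * Δa * t
        Δd = dot(bc, Δa) * Matrix{eltype(d)}(I, size(d, 1), size(d, 2))  # reuses bc
        return (Δb, Δc, Δd)
    end
    return a, back
end

b, c, d, Δa = randn(2, 3), randn(3, 4), randn(5, 5), randn(2, 4)
a, back = fun_cached(b, c, d)
Δb, Δc, Δd = back(Δa)
```

With many tensors, the number of such shared sub-products grows quickly. Choosing which ones to keep is the hard part that a one-expression-at-a-time rewrite does not see.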
It now also understands other macros that share the same syntax, like `@einsum`.
This allows it to treat non-Einstein contractions, such as batched matrix multiplication:
```julia
@grad x @einsum z[i,k,b] := x[i,j,b] * y[j,k,b]
```
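In the batched example, the index `b` appears in every factor and in the output, so it is not summed over. The only sum is over `j`. The plain-Julia sketch below spells out, as explicit loops, what the forward pass and the single sensitivity requested by `@grad x` compute. The function names are hypothetical, and this is not the expansion that `@einsum` generates.

```julia
# z[i,k,b] = Σ_j x[i,j,b] * y[j,k,b]: a matrix product for each batch slice b
function batched_mul(x, y)
    ni, nj, nb = size(x)
    nk = size(y, 2)
    z = zeros(promote_type(eltype(x), eltype(y)), ni, nk, nb)
    for b in 1:nb, k in 1:nk, j in 1:nj, i in 1:ni
        z[i, k, b] += x[i, j, b] * y[j, k, b]
    end
    return z
end

# Sensitivity for x only: Δx[i,j,b] = Σ_k Δz[i,k,b] * y[j,k,b]
function batched_mul_back_x(Δz, y)
    ni, nk, nb = size(Δz)
    nj = size(y, 1)
    Δx = zeros(promote_type(eltype(Δz), eltype(y)), ni, nj, nb)
    for b in 1:nb, k in 1:nk, j in 1:nj, i in 1:ni
        Δx[i, j, b] += Δz[i, k, b] * y[j, k, b]
    end
    return Δx
end

x, y, Δz = randn(2, 3, 4), randn(3, 5, 4), randn(2, 5, 4)
z = batched_mul(x, y)
Δx = batched_mul_back_x(Δz, y)
# Each batch slice agrees with ordinary matrix algebra
@assert all(z[:, :, b] ≈ x[:, :, b] * y[:, :, b] for b in 1:4)
@assert all(Δx[:, :, b] ≈ Δz[:, :, b] * y[:, :, b]' for b in 1:4)
```

The reverse rule has the same shape as the forward one. The output index `k` becomes the summed index, and `b` stays free throughout.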
Those are also handled by `@ein` from [OMEinsum.jl](https://github.com/under-Peter/OMEinsum.jl),
which may make this pointless, as that package has its own built-in gradients.
Probably you should use that instead!
An earlier attempt is now [TensorTrack.jl](https://github.com/mcabbott/TensorTrack.jl), which works at the level of
functions like `contract!`, and thus gets some re-use (4).
But it is completely limited by 2, being deeply plugged into TensorOperations.
Finally, note also that [TensorCast.jl](https://github.com/mcabbott/TensorCast.jl) should
be almost fully differentiable (although focused on operations other than contractions).
--- Michael Abbott, August 2019
### Update:
Essentially the same code has been bolted onto [Tullio.jl](https://github.com/mcabbott/Tullio.jl),
originally in [PR#6](https://github.com/mcabbott/Tullio.jl/pull/6), and was moved to `@tensor` in [PR#92](https://github.com/mcabbott/Tullio.jl/pull/92). It has the same limitations as above.
(But it avoids `eval` by always attaching gradients to a callable struct `Eval`, rather than to the newly defined functions.)
The package [TensorRules.jl](https://github.com/ho-oto/TensorRules.jl) has a macro `@∇` which performs
manipulations of `@tensor` expressions, acting on whole functions containing them.