TensorRules.jl
Macros to define custom adjoints for TensorOperations.jl
- Host: GitHub
- URL: https://github.com/ho-oto/TensorRules.jl
- Owner: ho-oto
- License: mit
- Created: 2020-10-28T20:08:38.000Z (about 4 years ago)
- Default Branch: master
- Last Pushed: 2021-08-17T01:03:25.000Z (about 3 years ago)
- Last Synced: 2024-10-18T05:40:56.906Z (26 days ago)
- Language: Julia
- Size: 101 KB
- Stars: 15
- Watchers: 1
- Forks: 0
- Open Issues: 1
Metadata Files:
- Readme: README.md
- License: LICENSE
# TensorRules.jl
[![Build Status](https://github.com/ho-oto/TensorRules.jl/workflows/CI/badge.svg)](https://github.com/ho-oto/TensorRules.jl/actions)
[![Code Style: Blue](https://img.shields.io/badge/code%20style-blue-4495d1.svg)](https://github.com/invenia/BlueStyle)

`TensorRules.jl` provides a macro `@∇` (you can type `∇` as `\nabla`), which
lets us use automatic differentiation (AD) libraries (e.g.,
[`Zygote.jl`](https://github.com/FluxML/Zygote.jl),
[`Diffractor.jl`](https://github.com/JuliaDiff/Diffractor.jl))
with the `@tensor` and `@tensoropt` macros in [`TensorOperations.jl`](https://github.com/Jutho/TensorOperations.jl).

`TensorRules.jl` uses [`ChainRulesCore.jl`](https://github.com/JuliaDiff/ChainRulesCore.jl) to define custom adjoints,
so you can use any AD library that supports `ChainRulesCore.jl`.

## How to use
```julia
julia> using TensorOperations, TensorRules, Zygote;
julia> function foo(a, b, c) # define a function with an Einstein summation
# d_F = \sum_{A,B,C,D,E} a_{A,B,C} b_{C,D,E,F} c_{A,B,D,E}
@tensor d[F] := a[A, B, C] * b[C, D, E, F] * c[A, B, D, E]
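# `d` has only the free index F, so it is a vector; return one element so `foo` outputs a scalar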
return d[1]
end;
julia> a, b, c = randn(3, 4, 5), randn(5, 6, 7, 8), randn(3, 4, 6, 7);
julia> gradient(foo, a, b, c); # try to obtain gradient of `foo` by Zygote
ERROR: this intrinsic must be compiled to be called
Stacktrace:
...
julia> @∇ function foo(a, b, c) # use @∇
@tensor d[F] := a[A, B, C] * b[C, D, E, F] * c[A, B, D, E]
return d[1]
end;
julia> gradient(foo, a, b, c); # it works!
```
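As a quick check, `gradient` returns one array per argument, each with the same shape as the corresponding input:

```julia
julia> Δa, Δb, Δc = gradient(foo, a, b, c);

julia> size(Δa) == size(a) && size(Δb) == size(b) && size(Δc) == size(c)
true
```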
## How it works

The strategy of `TensorRules.jl` is very similar to that of [`TensorGrad.jl`](https://github.com/mcabbott/TensorGrad.jl).
`@∇` rewrites functions that contain `@tensor` or `@tensoropt` macro calls.
First, `@∇` detects the `@tensor` and `@tensoropt` expressions in the function definition
and converts them into inlined functions.
Then, `@∇` defines custom adjoint rules for the generated functions.

For example, the following definition
```julia
@∇ function foo(a, b, c, d, e, f)
@tensoropt !C x[A, B] := conj(a[A, C]) * sin.(b)[C, D] * c.d[D, B] + d * e[1, 2][A, B]
x = x + f
@tensor x[A, B] += a[A, C] * (a * a)[C, B]
return x
end
```

will be converted to code equivalent to
```julia
function foo(a, b, c, d, e, f)
x = _foo_1(a, sin.(b), c.d, d, e[1, 2])
x = x + f
x += _foo_2(a, a * a)
return x
end

@inline _foo_1(x1, x2, x3, x4, x5) =
@tensoropt !C _[A, B] := conj(x1[A, C]) * x2[C, D] * x3[D, B] + x4 * x5[A, B]

@inline _foo_2(x1, x2) = @tensor _[A, B] := x1[A, C] * x2[C, B]
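
# `@∇` also defines a ChainRulesCore `rrule` for each extracted function;
# the pullback below builds a (thunked) adjoint for every input of `_foo_1`: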
function rrule(::typeof(_foo_1), x1, x2, x3, x4, x5)
f = _foo_1(x1, x2, x3, x4, x5)
Px1, Px2, Px3, Px4, Px5 = ProjectTo(x1), ProjectTo(x2), ProjectTo(x3), ProjectTo(x4), ProjectTo(x5)
function _foo_1_pullback(Δf)
fnΔx1(Δf, x1, x2, x3, x4, x5) = @tensoropt !C _[A, C] := conj(Δf[A, B]) * x2[C, D] * x3[D, B]
fnΔx1add!!(x, Δf, x1, x2, x3, x4, x5) = @tensoropt !C x[A, C] += conj(Δf[A, B]) * x2[C, D] * x3[D, B]
fnΔx2(Δf, x1, x2, x3, x4, x5) = @tensoropt !C _[C, D] := conj(conj(x1[A, C]) * conj(Δf[A, B]) * x3[D, B])
fnΔx2add!!(x, Δf, x1, x2, x3, x4, x5) = @tensoropt !C x[C, D] += conj(conj(x1[A, C]) * conj(Δf[A, B]) * x3[D, B])
fnΔx3(Δf, x1, x2, x3, x4, x5) = @tensoropt !C _[D, B] := conj(conj(x1[A, C]) * x2[C, D] * conj(Δf[A, B]))
fnΔx3add!!(x, Δf, x1, x2, x3, x4, x5) = @tensoropt !C x[D, B] += conj(conj(x1[A, C]) * x2[C, D] * conj(Δf[A, B]))
fnΔx4(Δf, x1, x2, x3, x4, x5) = first(@tensoropt !C _[] := conj(conj(Δf[A, B]) * x5[A, B]))
fnΔx5(Δf, x1, x2, x3, x4, x5) = @tensoropt !C _[A, B] := conj(x4 * conj(Δf[A, B]))
fnΔx5add!!(x, Δf, x1, x2, x3, x4, x5) = @tensoropt !C x[A, B] += conj(x4 * conj(Δf[A, B]))
Δx1 = InplaceableThunk(
Thunk(() -> Px1(fnΔx1(Δf, x1, x2, x3, x4, x5))),
x -> fnΔx1add!!(x, Δf, x1, x2, x3, x4, x5)
)
Δx2 = InplaceableThunk(
Thunk(() -> Px2(fnΔx2(Δf, x1, x2, x3, x4, x5))),
x -> fnΔx2add!!(x, Δf, x1, x2, x3, x4, x5)
)
Δx3 = InplaceableThunk(
Thunk(() -> Px3(fnΔx3(Δf, x1, x2, x3, x4, x5))),
x -> fnΔx3add!!(x, Δf, x1, x2, x3, x4, x5)
)
Δx4 = Thunk(() -> fnΔx4(Δf, x1, x2, x3, x4, x5))
Δx5 = InplaceableThunk(
Thunk(() -> Px5(fnΔx5(Δf, x1, x2, x3, x4, x5))),
x -> fnΔx5add!!(x, Δf, x1, x2, x3, x4, x5)
)
return (NoTangent(), Δx1, Δx2, Δx3, Δx4, Δx5)
end
return f, _foo_1_pullback
end

function rrule(::typeof(_foo_2), x1, x2)
...
end
```

By using `Thunk` and `InplaceableThunk` properly, adjoints will be evaluated only
if they are needed.
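As a rough, standalone illustration of that laziness (this is not code generated by `TensorRules.jl`; `expensive_adjoint` is a made-up stand-in), a `Thunk` from `ChainRulesCore.jl` only runs its computation when `unthunk` is called on it:

```julia
using ChainRulesCore: Thunk, unthunk

# stand-in for one of the generated adjoint contractions
expensive_adjoint() = (println("computing adjoint"); randn(3, 3))

Δx = Thunk(expensive_adjoint)  # nothing has been computed yet
# an AD engine that never needs this gradient simply discards Δx and skips the work
unthunk(Δx)                    # only now does `expensive_adjoint` actually run
```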
## Unsupported features

- `@∇` uses the `@capture` macro defined in [`MacroTools.jl`](https://github.com/FluxML/MacroTools.jl)
to parse `Expr`s. Because of limitations of the `@capture` macro,
index notations based on `:typed_vcat` and `:typed_hcat` (`A[a; b]`, `A[a b]`)
are unsupported. Please use the `A[a, b]` style.
- Specifying the contraction order with `ord=(...)` or in NCON style is unsupported.
Please use `@tensoropt` and specify the cost of each bond.
- Since `Zygote.jl` does not support in-place operations, we cannot use `@tensor A[] = ...`
in the expression. Please use `:=`, `+=`, and `-=` instead (see the sketch below).
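For example, a definition that stays within these restrictions could look like the following (an illustrative sketch; `bar` and the index names are made up):

```julia
using TensorOperations, TensorRules, Zygote

@∇ function bar(a, b)
    # comma-separated indices and `:=` (allocate a new array) instead of in-place `=`
    @tensor c[i, j] := a[i, k] * b[k, j]
    return c[1, 1]
end

gradient(bar, randn(3, 4), randn(4, 5));  # works because only supported syntax is used
```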
## TODO

- [ ] support `@tensor` block (`@tensor begin ... end`)
- [ ] support higher-order differentiation (by applying `@∇` to `rrule` and `frule` recursively)
- [ ] better support of `InplaceableThunk`