https://github.com/thummeto/forwarddiffchainrules.jl
- Host: GitHub
- URL: https://github.com/thummeto/forwarddiffchainrules.jl
- Owner: ThummeTo
- License: mit
- Created: 2022-10-10T13:30:56.000Z (over 2 years ago)
- Default Branch: master
- Last Pushed: 2024-01-02T09:36:03.000Z (about 1 year ago)
- Last Synced: 2024-11-23T12:05:14.501Z (about 2 months ago)
- Language: Julia
- Size: 37.1 KB
- Stars: 32
- Watchers: 3
- Forks: 4
- Open Issues: 3
Metadata Files:
- Readme: README.md
- License: LICENSE
# ForwardDiffChainRules.jl
[![Run Tests](https://github.com/ThummeTo/ForwardDiffChainRules.jl/actions/workflows/Test.yml/badge.svg)](https://github.com/ThummeTo/ForwardDiffChainRules.jl/actions/workflows/Test.yml)
[![Coverage](https://codecov.io/gh/ThummeTo/ForwardDiffChainRules.jl/branch/master/graph/badge.svg)](https://codecov.io/gh/ThummeTo/ForwardDiffChainRules.jl)
[![ColPrac: Contributor's Guide on Collaborative Practices for Community Packages](https://img.shields.io/badge/ColPrac-Contributor's%20Guide-blueviolet)](https://github.com/SciML/ColPrac)

## What is ForwardDiffChainRules.jl?
[ForwardDiff.jl](https://github.com/JuliaDiff/ForwardDiff.jl) does not support `frule`s from [ChainRulesCore.jl](https://github.com/JuliaDiff/ChainRulesCore.jl) by default. As a result, if you are creating custom AD-rules and want support for the most common AD-tools in Julia, you need to define three differentiation rules:
- `ChainRulesCore.rrule`
- `ChainRulesCore.frule`
- an additional dispatch of your function for values of type `ForwardDiff.Dual`
Technically, the last two both compute forward sensitivities, so they contain the same differentiation rules. Writing them twice is redundant and error-prone, and it is no longer necessary!

[ForwardDiffChainRules.jl](https://github.com/ThummeTo/ForwardDiffChainRules.jl) allows you to re-use the differentiation code defined in an existing `ChainRulesCore.frule` with only a few lines of code and without re-coding your differentiation rules.
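To illustrate the idea of re-using a forward rule, here is a minimal sketch in plain Julia that needs no packages. It is not how this package is implemented: `ToyDual`, `toy_frule` and `g` are hypothetical names made up for this example, with `ToyDual` standing in for `ForwardDiff.Dual` and `toy_frule` for a `ChainRulesCore.frule`. The point is that the dual-number dispatch only unpacks values and partials and forwards them to the rule, so the derivative is written once.

```julia
# Toy dual number: a value plus one directional derivative.
# Hypothetical stand-in for ForwardDiff.Dual (assumption for this sketch).
struct ToyDual
    value::Float64
    partial::Float64
end

# The function we want to differentiate.
g(x1::Real, x2::Real) = (x1 + 2 * x2)^2

# Hand-written forward rule in the style of an frule (hypothetical name):
# returns the primal result and the directional derivative.
function toy_frule((Δx1, Δx2), ::typeof(g), x1, x2)
    y = g(x1, x2)
    Δy = 2 * (x1 + 2 * x2) * (Δx1 + 2 * Δx2)
    return y, Δy
end

# Dual-number dispatch: unpack, forward to the rule, repack.
# No differentiation code is repeated here.
function g(d1::ToyDual, d2::ToyDual)
    y, Δy = toy_frule((d1.partial, d2.partial), g, d1.value, d2.value)
    return ToyDual(y, Δy)
end

# Seed the first argument to get the partial derivative w.r.t. x1 at (1, 2).
result = g(ToyDual(1.0, 1.0), ToyDual(2.0, 0.0))
```

Seeding the partial of the first argument with `1.0` and the second with `0.0` yields the derivative with respect to `x1` in `result.partial`. Writing this forwarding dispatch by hand for every function and argument type is the repetitive part that a macro can generate.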
## How can I use ForwardDiffChainRules.jl?
1\. Open a Julia-REPL, switch to package mode using `]`, and activate your preferred environment.

2\. Install [*ForwardDiffChainRules.jl*](https://github.com/ThummeTo/ForwardDiffChainRules.jl):
```julia-repl
(@v1) pkg> add ForwardDiffChainRules
```

3\. If you want to check that everything works correctly, you can run the tests bundled with [*ForwardDiffChainRules.jl*](https://github.com/ThummeTo/ForwardDiffChainRules.jl):
```julia-repl
(@v1) pkg> test ForwardDiffChainRules
```

## How can I add a dispatch for ForwardDiff based on an existing `frule`?
```julia
using ForwardDiffChainRules

function f1(x1, x2)
    # do whatever you want to do in your function
    return (x1 + 2x2).^2
end

# define your frule for function f1 as usual
function ChainRulesCore.frule((_, Δx1, Δx2), ::typeof(f1), x1, x2)
    # this could be any code you want of course;
    # here it is the element-wise directional derivative of (x1 + 2x2).^2
    return f1(x1, x2), 2 .* (x1 + 2x2) .* (Δx1 + 2Δx2)
end

# create a ForwardDiff-dispatch for scalar type `x1` and `x2`
@ForwardDiff_frule f1(x1::ForwardDiff.Dual, x2::ForwardDiff.Dual)

# create a ForwardDiff-dispatch for vector type `x1` and `x2`
@ForwardDiff_frule f1(x1::AbstractVector{<:ForwardDiff.Dual}, x2::AbstractVector{<:ForwardDiff.Dual})

# create a ForwardDiff-dispatch for matrix type `x1` and `x2`
@ForwardDiff_frule f1(x1::AbstractMatrix{<:ForwardDiff.Dual}, x2::AbstractMatrix{<:ForwardDiff.Dual})
```

## Acknowledgement
This package is based on code from Mohamed Tarek (@mohamed82008) in his package [NonconvexUtils.jl](https://github.com/JuliaNonconvex/NonconvexUtils.jl). The initial discussion started on [discourse.julialang.org](https://discourse.julialang.org/t/chainrulescore-and-forwarddiff/61705). This package was created to provide that functionality in a package that is as lightweight as possible.