Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/rejuvyesh/pycallchainrules.jl
Differentiate python calls from Julia
- Host: GitHub
- URL: https://github.com/rejuvyesh/pycallchainrules.jl
- Owner: rejuvyesh
- License: MIT
- Created: 2022-01-20T21:53:24.000Z (almost 3 years ago)
- Default Branch: main
- Last Pushed: 2022-07-19T18:55:37.000Z (over 2 years ago)
- Last Synced: 2024-10-13T19:29:46.033Z (about 1 month ago)
- Topics: autodifferentiation, jax, julia, pytorch
- Language: Julia
- Homepage:
- Size: 245 KB
- Stars: 56
- Watchers: 5
- Forks: 2
- Open Issues: 4
Metadata Files:
- Readme: README.md
- License: LICENSE
Awesome Lists containing this project
README
# PyCallChainRules
[![](https://img.shields.io/badge/docs-stable-blue.svg)](https://rejuvyesh.github.io/PyCallChainRules.jl/stable)
[![](https://img.shields.io/badge/docs-dev-blue.svg)](https://rejuvyesh.github.io/PyCallChainRules.jl/dev)

While Julia is great, there is still a lot of existing, useful, differentiable Python code in PyTorch, Jax, etc. Given that PyCall.jl is already so great and seamless, one might wonder what it takes to differentiate through those `pycall`s. This library aims for that ideal.
Thanks to [@pabloferz](https://github.com/pabloferz), this works on both CPU and GPU without any array copies, via [DLPack.jl](https://github.com/pabloferz/DLPack.jl).
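As a rough illustration of that zero-copy sharing (a minimal sketch, not taken from this package's API, assuming a working PyTorch installation), DLPack.jl can wrap a PyTorch tensor as a Julia array that views the same memory:

```julia
using PyCall
using DLPack

torch = pyimport("torch")
dlpack = pyimport("torch.utils.dlpack")

t = torch.ones(3, 4)                  # tensor allocated and owned by PyTorch
A = DLPack.wrap(t, dlpack.to_dlpack)  # Julia array viewing the same buffer, no copy
@assert sum(A) == 12
t.mul_(2)                             # mutate on the Python side...
@assert sum(A) == 24                  # ...and the Julia view sees the change
```

The same mechanism is what allows the GPU path to work without copies, with the wrapped array being a `CuArray` instead of an `Array`.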
## Basic Usage
### PyTorch
#### CPU only
##### Install Python dependencies
```julia
using PyCall
run(`$(PyCall.pyprogramname) -m pip install torch==1.11.0+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html functorch`)
```

##### Example
```julia
using PyCallChainRules.Torch: TorchModuleWrapper, torch
using Zygote

indim = 32
outdim = 16
torch_module = torch.nn.Linear(indim, outdim) # Can be anything subclassing torch.nn.Module
jlwrap = TorchModuleWrapper(torch_module)

batchsize = 64
input = randn(Float32, indim, batchsize)
output = jlwrap(input)

target = randn(Float32, outdim, batchsize)
loss(m, x, y) = sum(m(x) .- y)
grad, = Zygote.gradient(m->loss(m, input, target), jlwrap)
```

#### GPU
##### Install Python dependencies
```julia
using PyCall
# For CUDA 11 and PyTorch 1.11
run(`$(PyCall.pyprogramname) -m pip install torch==1.11.0+cu113 -f https://download.pytorch.org/whl/cu113/torch_stable.html functorch`)
```

##### Example
```julia
using CUDA
using PyCallChainRules.Torch: TorchModuleWrapper, torch
using Zygote

@assert CUDA.functional()
indim = 32
outdim = 16
torch_module = torch.nn.Linear(indim, outdim).to(device=torch.device("cuda:0")) # Can be anything subclassing torch.nn.Module
jlwrap = TorchModuleWrapper(torch_module)

batchsize = 64
input = CUDA.cu(randn(Float32, indim, batchsize))
output = jlwrap(input)

target = CUDA.cu(randn(Float32, outdim, batchsize))
loss(m, x, y) = sum(m(x) .- y)
grad, = Zygote.gradient(m->loss(m, input, target), jlwrap)
```

### Jax
#### CPU only
##### Install Python dependencies
```julia
using PyCall
run(`$(PyCall.pyprogramname) -m pip install jax\["cpu"\]`) # for cpu version
```

##### Example
```julia
using PyCallChainRules.Jax: JaxFunctionWrapper, jax, stax, pyto_dlpack
using DLPack
using Zygote

batchsize = 64
indim = 32
outdim = 16

init_lin, apply_lin = stax.Dense(outdim)
_, params = init_lin(jax.random.PRNGKey(0), (-1, indim))
params_jl = map(x->DLPack.wrap(x, pyto_dlpack), params)
jlwrap = JaxFunctionWrapper(jax.jit(apply_lin))
input = randn(Float32, indim, batchsize)
output = jlwrap(params_jl, input)

target = randn(Float32, outdim, batchsize)
loss(p, x, y) = sum(jlwrap(p, x) .- y)
grad, = Zygote.gradient(p->loss(p, input, target), params_jl)
```

#### GPU
##### Install Python dependencies
```julia
using PyCall
run(`$(PyCall.pyprogramname) -m pip install jax\["cuda"\] -f https://storage.googleapis.com/jax-releases/jax_releases.html`)
```

##### Example
```julia
using CUDA
using DLPack
using Zygote
using PyCallChainRules.Jax: JaxFunctionWrapper, jax, stax, pyto_dlpack
batchsize = 64
indim = 32
outdim = 16

init_lin, apply_lin = stax.Dense(outdim)
_, params = init_lin(jax.random.PRNGKey(0), (-1, indim))
params_jl = map(x->DLPack.wrap(x, pyto_dlpack), params)
jlwrap = JaxFunctionWrapper(jax.jit(apply_lin))
input = CUDA.cu(randn(Float32, indim, batchsize))
output = jlwrap(params_jl, input)

target = CUDA.cu(randn(Float32, outdim, batchsize))
loss(p, x, y) = sum(jlwrap(p, x) .- y)
grad, = Zygote.gradient(p->loss(p, input, target), params_jl)
```

When mixing `jax` and `julia`, it is recommended to disable `jax`'s preallocation by setting the environment variable `XLA_PYTHON_CLIENT_PREALLOCATE=false`.
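One way to do that from Julia (a small sketch; the variable can equally be exported in the shell before launching Julia) is to set it before `jax` is first imported in the session:

```julia
# Must be set before jax is loaded for the first time in this process,
# e.g. before `using PyCallChainRules.Jax`.
ENV["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

using PyCallChainRules.Jax: JaxFunctionWrapper, jax
```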
## Current Limitations
- Input and output types of wrapped Python functions can only be Python tensors or (nested) tuples of Python tensors.
- Keyword arguments must not be arrays and are not differentiated; one possible way to handle keyword arguments is sketched below.
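For functions that do take keyword arguments, a possible workaround (a hedged sketch; `apply_fn` and `scale` are illustrative names, not part of this package) is to bind the keyword arguments on the Python side, e.g. with `functools.partial`, so that the wrapped callable only receives tensor positional arguments:

```julia
using PyCall
using PyCallChainRules.Jax: JaxFunctionWrapper

# A toy Python-side function with a keyword argument (illustrative only).
py"""
import jax.numpy as jnp

def apply_fn(x, scale=1.0):
    return jnp.tanh(scale * x)
"""

functools = pyimport("functools")

# Fix the keyword argument up front; it stays a plain Python scalar and
# never enters the differentiated call.
apply_fixed = functools.partial(py"apply_fn", scale=0.5)
jlwrap = JaxFunctionWrapper(apply_fixed)

input = randn(Float32, 8, 4)
output = jlwrap(input)
```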