Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/lucidrains/jax2torch
Use Jax functions in Pytorch
Last synced: 13 days ago
- Host: GitHub
- URL: https://github.com/lucidrains/jax2torch
- Owner: lucidrains
- License: MIT
- Created: 2021-10-26T00:16:32.000Z (about 3 years ago)
- Default Branch: main
- Last Pushed: 2023-07-01T17:24:32.000Z (over 1 year ago)
- Last Synced: 2024-10-23T00:06:05.926Z (21 days ago)
- Topics: deep-learning-framework, jax, torch
- Language: Python
- Homepage:
- Size: 17.6 KB
- Stars: 225
- Watchers: 5
- Forks: 9
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
Awesome Lists containing this project
README
## jax2torch
Use Jax functions in Pytorch with DLPack, as outlined in a gist by @mattjj. The repository was created to make this differentiable alignment work interoperable with Pytorch projects.
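For context, the DLPack handoff the README refers to amounts to exporting a tensor from one framework and importing it into the other without copying. A minimal sketch of that mechanism (not jax2torch's internal code) might look like:

```python
import jax.dlpack
import torch
import torch.utils.dlpack

# Sketch of the zero-copy DLPack handoff jax2torch builds on
# (not the library's actual implementation).

# Torch -> Jax: export a DLPack capsule and import it on the Jax side
x_torch = torch.arange(3.)
x_jax = jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(x_torch))

# Jax -> Torch: the reverse direction
y_jax = x_jax * 2
y_torch = torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(y_jax))

print(y_torch)  # tensor([0., 2., 4.])
```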
## Install
```bash
$ pip install jax2torch
```

## Memory management
By default, Jax pre-allocates 90% of VRAM, leaving Pytorch very little to work with. To prevent this behavior, set the `XLA_PYTHON_CLIENT_PREALLOCATE` environment variable to false before running any Jax code:
```python
import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
```

## Usage
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1GBEEnpuCvLS1bhb_xGCO5Y40rFiQrh6G?usp=sharing)

Quick test
```python
import jax
import torch
from jax2torch import jax2torch
import os

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
# Jax function
@jax.jit
def jax_pow(x, y = 2):
    return x ** y

# convert to Torch function
torch_pow = jax2torch(jax_pow)
# run it on Torch data!
x = torch.tensor([1., 2., 3.])
y = torch_pow(x, y = 3)
print(y)  # tensor([1., 8., 27.])

# And differentiate!
x = torch.tensor([2., 3.], requires_grad = True)
y = torch.sum(torch_pow(x, y = 3))
y.backward()
print(x.grad) # tensor([12., 27.])
```
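Because the wrapped function participates in Torch autograd, it can also sit inside an ordinary Torch training step. The sketch below uses only the jax2torch call shown above plus standard Pytorch; the parameter and optimizer are illustrative, not from the original README.

```python
import jax
import torch
from jax2torch import jax2torch

@jax.jit
def jax_pow(x, y = 2):
    return x ** y

torch_pow = jax2torch(jax_pow)

# illustrative parameter and optimizer (assumption, not from the original README)
w = torch.randn(3, requires_grad = True)
opt = torch.optim.SGD([w], lr = 1e-2)

loss = torch_pow(w, y = 3).sum()
loss.backward()  # gradients flow back through the Jax function
opt.step()
```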