https://github.com/abdulfatir/learn-jax
Learning JAX as a PyTorch User
- Host: GitHub
- URL: https://github.com/abdulfatir/learn-jax
- Owner: abdulfatir
- License: mit
- Created: 2023-08-27T13:33:21.000Z (about 2 years ago)
- Default Branch: main
- Last Pushed: 2023-08-27T13:44:17.000Z (about 2 years ago)
- Last Synced: 2025-02-09T07:41:35.960Z (8 months ago)
- Language: Jupyter Notebook
- Size: 45.9 KB
- Stars: 2
- Watchers: 2
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
README
# Learning JAX as a PyTorch User
As a PyTorch user, I found it somewhat challenging to get started with JAX because of its peculiar features (e.g., explicit PRNG state management) and an ecosystem of libraries that can be confusing for beginners. This repo contains self-contained notebooks implementing deep learning models in JAX. In particular, I use the following libraries:
- JAX
- Flax: for deep learning modules, analogous to `torch.nn`.
- Optax: for optimizers, analogous to `torch.optim`.
- Distrax: for distributions, analogous to `torch.distributions`.

## Models
The following models have been implemented (checked) or are planned (unchecked).
- [x] Variational Autoencoder (for MNIST): [./notebooks/vae.ipynb](./notebooks/vae.ipynb)
- [ ] DDPM
- [ ] Linear Gaussian SSM
- [ ] Linear Gaussian SSM w/ Parallel Inference