Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/patil-suraj/stable-diffusion-jax
- Host: GitHub
- URL: https://github.com/patil-suraj/stable-diffusion-jax
- Owner: patil-suraj
- License: mit
- Created: 2022-08-12T07:04:12.000Z (over 2 years ago)
- Default Branch: main
- Last Pushed: 2022-09-19T17:15:04.000Z (over 2 years ago)
- Last Synced: 2024-10-23T11:08:46.408Z (2 months ago)
- Language: Python
- Size: 402 KB
- Stars: 86
- Watchers: 5
- Forks: 8
- Open Issues: 6
Metadata Files:
- Readme: README.md
- License: LICENSE
README
## TODOs:
- [x] Finish implementing the `UNet2D` model in `modeling_unet2d.py`. Port the weights of any existing LDM UNet from `diffusers` and verify equivalence. I've added the skeleton of the modules that we need to implement in that file.
- [x] Adapt the `PNDMScheduler` from `diffusers` for JAX: Use `jnp` arrays and make it stateless.
- [x] Add the KL module from [here](https://github.dev/CompVis/stable-diffusion) to the `modeling_vae.py` file. We don't really need it for inference, but it would be nice to have for completeness. Port the weights of any existing KL VAE and verify equivalence.
- [x] Add an inference loop in `pipeline_stable_diffusion`. We should be able to `jit`/`pmap` the loop to deploy on TPUs.
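
The "stateless scheduler" and "`jit`-able loop" items above can be illustrated with a minimal sketch. It is not the code in this repo: for brevity it swaps in a deterministic DDIM-style update instead of PNDM, and `SchedulerState`, `create_state`, `step`, `toy_denoiser`, `sample_loop`, and the beta values are all hypothetical names and defaults chosen for illustration. The idea is that the scheduler data lives in an immutable pytree passed into pure functions, so the whole loop can be traced by `jax.jit`.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class SchedulerState(NamedTuple):
    """Immutable scheduler data; step functions read it but never mutate it."""
    timesteps: jnp.ndarray       # descending timesteps, shape (num_inference_steps,)
    alphas_cumprod: jnp.ndarray  # cumulative product of (1 - beta)


def create_state(num_train_steps=1000, num_inference_steps=50,
                 beta_start=0.00085, beta_end=0.012):
    # Illustrative "scaled linear" beta schedule; the defaults are assumptions.
    betas = jnp.linspace(beta_start ** 0.5, beta_end ** 0.5, num_train_steps) ** 2
    alphas_cumprod = jnp.cumprod(1.0 - betas)
    stride = num_train_steps // num_inference_steps
    timesteps = (jnp.arange(num_inference_steps) * stride)[::-1]
    return SchedulerState(timesteps=timesteps, alphas_cumprod=alphas_cumprod)


def step(state, i, sample, noise_pred):
    """One DDIM-style update (a simplification, not PNDM). Pure function."""
    last = state.timesteps.shape[0] - 1
    alpha_t = state.alphas_cumprod[state.timesteps[i]]
    t_prev = state.timesteps[jnp.minimum(i + 1, last)]
    # After the final step there is no noise left, so alpha_prev is 1.
    alpha_prev = jnp.where(i == last, 1.0, state.alphas_cumprod[t_prev])
    pred_x0 = (sample - jnp.sqrt(1.0 - alpha_t) * noise_pred) / jnp.sqrt(alpha_t)
    return jnp.sqrt(alpha_prev) * pred_x0 + jnp.sqrt(1.0 - alpha_prev) * noise_pred


def toy_denoiser(params, sample, t):
    # Hypothetical stand-in for the UNet: any pure function of (params, x, t).
    return params["scale"] * sample


@jax.jit
def sample_loop(params, state, latents):
    def body(i, x):
        noise_pred = toy_denoiser(params, x, state.timesteps[i])
        return step(state, i, x, noise_pred)

    # fori_loop keeps the whole denoising loop inside one compiled program.
    return jax.lax.fori_loop(0, state.timesteps.shape[0], body, latents)


params = {"scale": jnp.float32(0.1)}
state = create_state()
latents = jax.random.normal(jax.random.PRNGKey(0), (1, 4, 8, 8))
out = sample_loop(params, state, latents)
```

Because `sample_loop` has no side effects, the same function could in principle be wrapped with `jax.pmap` instead of `jax.jit`, with a leading device axis on `latents` and `params`/`state` broadcast to every device.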