# v-diffusion-jax
v objective diffusion inference code for JAX, by Katherine Crowson ([@RiversHaveWings](https://twitter.com/RiversHaveWings)) and Chainbreakers AI ([@jd_pressman](https://twitter.com/jd_pressman)).
The models are denoising diffusion probabilistic models (https://arxiv.org/abs/2006.11239), which are trained to reverse a gradual noising process, allowing the models to generate samples from the learned data distributions starting from random noise. DDIM-style deterministic sampling (https://arxiv.org/abs/2010.02502) is also supported. The models are also trained on continuous timesteps. They use the 'v' objective from Progressive Distillation for Fast Sampling of Diffusion Models (https://openreview.net/forum?id=TIdIXIpzhoI).
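For intuition, here is a minimal sketch of the 'v' parameterization, assuming the continuous-time cosine noise schedule these models commonly use (an illustration, not a transcription of the repository's code):
```
import jax.numpy as jnp

def alphas_sigmas(t):
    # Continuous-time cosine schedule (an assumption for this sketch); t in [0, 1].
    return jnp.cos(t * jnp.pi / 2), jnp.sin(t * jnp.pi / 2)

def predictions_from_v(x_t, v, t):
    # With x_t = alpha * x_0 + sigma * eps and the network output
    # v = alpha * eps - sigma * x_0, both x_0 and eps can be recovered.
    alpha, sigma = alphas_sigmas(t)
    pred_x0 = alpha * x_t - sigma * v
    pred_eps = sigma * x_t + alpha * v
    return pred_x0, pred_eps
```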
Thank you to Google's [TPU Research Cloud](https://sites.research.google/trc/about/) and [stability.ai](https://www.stability.ai) for compute to train these models!
## Dependencies
- JAX ([installation instructions](https://github.com/google/jax#installation))
- dm-haiku, einops, numpy, optax, Pillow, tqdm (install with `pip install`)
- CLIP_JAX (https://github.com/kingoflolz/CLIP_JAX), and its additional pip-installable dependencies: ftfy, regex, torch, torchvision (it does not need GPU PyTorch). **If you `git clone --recursive` this repo, it should fetch CLIP_JAX automatically.**
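A typical setup might look like the following; JAX itself should be installed first per its own instructions, and exact package versions are not pinned here:
```
# Fetches CLIP_JAX as a submodule via --recursive.
git clone --recursive https://github.com/crowsonkb/v-diffusion-jax
cd v-diffusion-jax
# Install JAX separately per its installation instructions, then:
pip install dm-haiku einops numpy optax Pillow tqdm ftfy regex torch torchvision
```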
## Model checkpoints
- [Danbooru SFW 128x128](https://the-eye.eu/public/AI/models/v-diffusion/danbooru_128.pkl), SHA-256 `8551fe663dae988e619444efd99995775c7618af2f15ab5d8caf6b123513c334`
- [ImageNet 128x128](https://the-eye.eu/public/AI/models/v-diffusion/imagenet_128.pkl), SHA-256 `4fc7c817b9aaa9018c6dbcbf5cd444a42f4a01856b34c49039f57fe48e090530`
- [WikiArt 128x128](https://the-eye.eu/public/AI/models/v-diffusion/wikiart_128.pkl), SHA-256 `8fbe4e0206262996ff76d3f82a18dc67d3edd28631d4725e0154b51d00b9f91a`
- [WikiArt 256x256](https://the-eye.eu/public/AI/models/v-diffusion/wikiart_256.pkl), SHA-256 `ebc6e77865bbb2d91dad1a0bfb670079c4992684a0e97caa28f784924c3afd81`
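To check a downloaded file against the hashes above with the standard `sha256sum` utility (adjust the path to wherever you saved the checkpoint):
```
sha256sum checkpoints/wikiart_256.pkl
# expected: ebc6e77865bbb2d91dad1a0bfb670079c4992684a0e97caa28f784924c3afd81
```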
## Sampling
### Example
If the model checkpoints are stored in `checkpoints/`, the following will generate an image:
```
./clip_sample.py "a friendly robot, watercolor by James Gurney" --model wikiart_256 --seed 0
```
If they are somewhere else, you need to specify the path to the checkpoint with `--checkpoint`.
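For example, with a checkpoint stored at an arbitrary location (the path below is just a placeholder):
```
./clip_sample.py "a friendly robot, watercolor by James Gurney" --model wikiart_256 --checkpoint /path/to/wikiart_256.pkl --seed 0
```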
### Unconditional sampling
```
usage: sample.py [-h] [--batch-size BATCH_SIZE] [--checkpoint CHECKPOINT] [--eta ETA] --model
{danbooru_128,imagenet_128,wikiart_128,wikiart_256} [-n N] [--seed SEED]
[--steps STEPS]
```

- `--batch-size`: sample this many images at a time (default 1)
- `--checkpoint`: manually specify the model checkpoint file
- `--eta`: set to 0 for deterministic (DDIM) sampling, 1 (the default) for stochastic (DDPM) sampling, and in between to interpolate between the two. DDIM is preferred for low numbers of timesteps.
- `--init`: specify the init image (optional)
- `--model`: specify the model to use
- `-n`: sample until this many images are sampled (default 1)
- `--seed`: specify the random seed (default 0)
- `--starting-timestep`: specify the starting timestep if an init image is used (range 0-1, default 0.9)
- `--steps`: specify the number of diffusion timesteps (default 1000; can be lowered for faster but lower quality sampling)
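For example, an unconditional run using only the options listed above (the values are illustrative):
```
./sample.py --model wikiart_256 --batch-size 4 -n 8 --steps 250 --seed 0
```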
### CLIP guided sampling
CLIP guided sampling lets you generate images with diffusion models conditional on the output matching a text prompt.
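Roughly, the guidance follows the gradient of a CLIP image-text similarity with respect to the image during sampling. A minimal sketch of that idea, assuming a CLIP image-embedding function `clip_image_embed` and a precomputed `text_embed` (hypothetical names; this is not the repository's actual implementation):
```
import jax
import jax.numpy as jnp

def clip_guidance_grad(clip_image_embed, text_embed, image, scale):
    # Loss: negative (scaled) similarity between the CLIP embedding of the
    # current image estimate and the prompt's text embedding.
    def loss(x):
        image_embed = clip_image_embed(x)  # hypothetical CLIP image encoder
        return -scale * jnp.sum(image_embed * text_embed)
    # The gradient of this loss is used to steer each sampling step
    # toward images that better match the prompt.
    return jax.grad(loss)(image)
```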
```
usage: clip_sample.py [-h] [--batch-size BATCH_SIZE] [--checkpoint CHECKPOINT]
[--clip-guidance-scale CLIP_GUIDANCE_SCALE] [--eta ETA] --model
{danbooru_128,imagenet_128,wikiart_128,wikiart_256} [-n N] [--seed SEED]
[--steps STEPS]
prompt
```
`clip_sample.py` has the same options as `sample.py` and these additional ones:

- `prompt`: the text prompt to use
- `--clip-guidance-scale`: how strongly the result should match the text prompt (default 1000)
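For instance, to push the output more strongly toward the prompt than the default guidance scale (values chosen for illustration, not recommendations):
```
./clip_sample.py "a friendly robot, watercolor by James Gurney" --model wikiart_256 --clip-guidance-scale 2000 --seed 0
```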