Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/gojasper/flash-diffusion
Official implementation of ⚡ Flash Diffusion ⚡: Accelerating Any Conditional Diffusion Model for Few Steps Image Generation
- Host: GitHub
- URL: https://github.com/gojasper/flash-diffusion
- Owner: gojasper
- License: other
- Created: 2024-06-04T15:14:40.000Z (23 days ago)
- Default Branch: main
- Last Pushed: 2024-06-07T19:04:13.000Z (20 days ago)
- Last Synced: 2024-06-08T14:28:21.895Z (19 days ago)
- Topics: diffusion-models, distillation, dit, inpainting, sdxl, super-resolution, text-to-image
- Language: Python
- Homepage: https://gojasper.github.io/flash-diffusion-project/
- Size: 18.4 MB
- Stars: 114
- Watchers: 3
- Forks: 7
- Open Issues: 3
Metadata Files:
- Readme: README.md
- License: LICENSE.txt
README
# ⚡ Flash Diffusion ⚡
This repository is the official implementation of the paper [Flash Diffusion: Accelerating Any Conditional Diffusion Model for Few Steps Image Generation](http://arxiv.org/abs/2406.02347).
Images generated using 4 NFEs (number of function evaluations)
In this paper, we propose an efficient, fast, versatile and LoRA-compatible distillation method to accelerate the generation of pre-trained diffusion models: *Flash Diffusion*. The method reaches state-of-the-art performance in terms of FID and CLIP-Score for few-step image generation on the COCO 2014 and COCO 2017 datasets, while requiring only **several GPU hours of training** and fewer trainable parameters than existing methods. In addition to its efficiency, the versatility of the method is demonstrated across several tasks such as text-to-image, inpainting, face-swapping and super-resolution, and across different diffusion model backbones, whether UNet-based denoisers (SD1.5, SDXL) or DiT (Pixart-α), as well as adapters. In all cases, the method drastically reduces the number of sampling steps while maintaining very high-quality image generation.
## Quick access
- [Method overview](#method)
- [Results overview](#results)
- [Installation 🛠️](#setup)
- [Text2Image model distillation](#distilling-existing-t2i-models)
- [Distilling a custom model 🚀](#example-of-a-distillation-training-with-a-custom-conditional-diffusion-model)
- [Inference with 🤗 Hugging Face pipelines](#inference-with-a-huggingface-pipeline-)
- [Citing this repository](#citation)

## Method
Our method aims to provide a fast, reliable, and adaptable distillation approach. We train a student model to predict, in a single step, the multi-step denoised prediction that the teacher produces from a corrupted input sample. Additionally, we sample timesteps from an adaptable distribution that shifts during training, helping the student model focus on specific timesteps.
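The idea above (a one-step student matching a multi-step teacher, with a timestep distribution that shifts during training) can be sketched on a toy 1-D example. This is purely illustrative: every function and number below is hypothetical and comes from neither the paper nor the `flash` package.

```python
import numpy as np

rng = np.random.default_rng(0)
MU = 2.0  # mean of the (toy) data distribution

def teacher_k_steps(x, k=8):
    """Stand-in for a multi-step teacher: k small denoising steps toward the data mean."""
    for _ in range(k):
        x = x + (MU - x) / k
    return x

# Student: a single affine map a * x + b, i.e. one "step".
a, b = 1.0, 0.0
lr = 0.02

for it in range(2000):
    # Timestep sampling distribution whose mean shifts during training
    # (here: from high noise levels toward low ones), mimicking the idea
    # of focusing the student on specific timesteps over time.
    t_mean = 0.9 - 0.8 * it / 2000
    t = np.clip(rng.normal(t_mean, 0.1), 0.0, 1.0)

    x0 = rng.normal(MU, 0.1)           # clean sample
    eps = rng.normal()                 # noise
    x_t = (1 - t) * x0 + t * eps       # corrupted sample

    target = teacher_k_steps(x_t)      # multi-step teacher prediction
    pred = a * x_t + b                 # one-step student prediction

    # Gradient step on the squared distillation error
    g = 2 * (pred - target)
    a -= lr * g * x_t
    b -= lr * g

# After training, the one-step student closely matches the k-step teacher.
xs = rng.normal(MU, 1.0, size=100)
mse = np.mean((a * xs + b - teacher_k_steps(xs)) ** 2)
print(f"distillation MSE: {mse:.4f}")
```

The real method of course distills a full denoising network rather than an affine map, but the structure of the objective (one student evaluation regressed onto a multi-step teacher trajectory) is the same.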
## Results
Flash Diffusion is compatible with various backbones such as
- [Flash SD](https://huggingface.co/jasperai/flash-sd), distilled from a [SD1.5 teacher](https://huggingface.co/runwayml/stable-diffusion-v1-5)
- [Flash SDXL](https://huggingface.co/jasperai/flash-sdxl), distilled from a [SDXL teacher](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
- [Flash Pixart (DiT)](https://huggingface.co/jasperai/flash-pixart), distilled from a [Pixart-α teacher](https://huggingface.co/PixArt-alpha/PixArt-XL-2-1024-MS)

### Varying backbones for *Text-to-image*
Flash SD

Images generated using 4 NFEs

Flash SDXL

Images generated using 4 NFEs

Flash Pixart (DiT)

Images generated using 4 NFEs
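For the SD1.5 backbone, few-step inference follows the same pattern as the Pixart pipeline shown later in this README. A minimal sketch (assuming the `jasperai/flash-sd` checkpoint ships LoRA weights loadable via `diffusers`' `load_lora_weights`; requires a CUDA GPU):

```python
import torch
from diffusers import DiffusionPipeline, LCMScheduler

# Base SD1.5 pipeline (the teacher backbone)
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    torch_dtype=torch.float16,
)

# Few-step sampling uses an LCM-style scheduler
pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config)

# Load and fuse the Flash Diffusion LoRA weights
pipe.load_lora_weights("jasperai/flash-sd")
pipe.fuse_lora()

pipe.to("cuda")

image = pipe(
    "A raccoon reading a book in a lush forest.",
    num_inference_steps=4,
    guidance_scale=0,
).images[0]
```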
### Varying Use-cases
Image-inpainting

Image-upscaling

Face-swapping

T2I-Adapters
## Setup
To get up and running, first create a virtual environment with at least `python3.10` installed and activate it.

### With `venv`
```bash
python3.10 -m venv envs/flash_diffusion
source envs/flash_diffusion/bin/activate
```

### With `conda`
```bash
conda create -n flash_diffusion python=3.10
conda activate flash_diffusion
```

Then install the required dependencies (if on GPU) and the repo in editable mode:
```bash
pip install --upgrade pip
pip install -r requirements.txt
pip install -e .
```

## Distilling existing T2I models
The scripts to reproduce the main experiments of the paper are located in the `examples` folder. We provide 4 different scripts:
- `train_flash_sd.py`: Distills the [SD1.5 model](https://huggingface.co/runwayml/stable-diffusion-v1-5)
- `train_flash_sdxl.py`: Distills the [SDXL model](https://huggingface.co/stabilityai/stable-diffusion-xl-base-1.0)
- `train_flash_pixart.py`: Distills the [Pixart-α model](https://huggingface.co/PixArt-alpha/PixArt-XL-2-1024-MS)
- `train_flash_canny_adapter.py`: Distills a [T2I Canny Adapter](https://huggingface.co/TencentARC/t2i-adapter-canny-sdxl-1.0?library=true)

In `examples/configs`, you will find the configuration `yaml` associated with each script. The only thing you need to do is amend the `SHARDS_PATH_OR_URLS` section of the `yaml` so that the model is trained on your own data. Please note that this package uses [`webdataset`](https://github.com/webdataset/webdataset) to handle the data stream, so the URLs you use should be formatted according to the [`webdataset` format](https://github.com/webdataset/webdataset?tab=readme-ov-file#the-webdataset-format). In particular, for these 4 examples, each sample needs to be composed of a `jpg` file containing the image and a `json` file containing the caption under the key `caption` and the image aesthetic score under the key `aesthetic_score`:
```
sample = {
"jpg": dummy_image,
"json": {
"caption": "dummy caption",
"aesthetic_score": 6.0
}
}
```

The scripts can then be launched by simply running:
```bash
# Set the number of gpus & nodes you want to use
export SLURM_NPROCS=1
export SLURM_NNODES=1

# Distills SD1.5
python3.10 examples/train_flash_sd.py

# Distills SDXL1.0
python3.10 examples/train_flash_sdxl.py

# Distills Pixart-α
python3.10 examples/train_flash_pixart.py

# Distills T2I Canny adapter
python3.10 examples/train_flash_canny_adapter.py
```

## Example of a distillation training with a custom conditional diffusion model
This package is also intended to support custom model distillation:
```python
from copy import deepcopy

from flash.models.unets import DiffusersUNet2DCondWrapper
from flash.models.vae import AutoencoderKLDiffusers, AutoencoderKLDiffusersConfig
from flash.models.embedders import (
    ClipEmbedder,
    ClipEmbedderConfig,
    ConditionerWrapper,
    TorchNNEmbedder,
    TorchNNEmbedderConfig,
)

# Create the VAE
vae_config = AutoencoderKLDiffusersConfig(
    "stabilityai/sdxl-vae"  # VAE from HF Hub
)
vae = AutoencoderKLDiffusers(config=vae_config)

## Create the Conditioners ##
# A CLIP conditioner returning 2 types of conditioning
embedder_1_config = ClipEmbedderConfig(
    version="stabilityai/stable-diffusion-xl-base-1.0",  # from HF Hub
    text_embedder_subfolder="text_encoder_2",
    tokenizer_subfolder="tokenizer_2",
    input_key="text",
    always_return_pooled=True,  # Also return a pooled 1-dimensional tensor
)
embedder_1 = ClipEmbedder(config=embedder_1_config)

# Embedder acting on a low-resolution image injected in the UNet via concatenation
embedder_2_config = TorchNNEmbedderConfig(
    nn_modules=["torch.nn.Conv2d"],
    nn_modules_kwargs=[
        dict(
            in_channels=3,
            out_channels=6,
            kernel_size=3,
            padding=1,
            stride=2,
        ),
    ],
    input_key="downsampled_image",
    unconditional_conditioning_rate=0.1,  # e.g. drop the conditioning 10% of the time
)
embedder_2 = TorchNNEmbedder(config=embedder_2_config)

conditioner_wrapper = ConditionerWrapper(
    conditioners=[embedder_1, embedder_2]
)

# Create the Teacher denoiser
teacher_denoiser = DiffusersUNet2DCondWrapper(
    in_channels=4 + 6,  # VAE channels + concat conditioning
    out_channels=4,  # VAE channels
    cross_attention_dim=1280,  # cross-attention conditioning
    projection_class_embeddings_input_dim=1280,  # add conditioning
    class_embed_type="projection",
)

# Load the teacher weights
...

# Create the student denoiser
student_denoiser = deepcopy(teacher_denoiser)
```

## Inference with a Huggingface pipeline 🤗
```python
import torch
from diffusers import PixArtAlphaPipeline, Transformer2DModel, LCMScheduler
from peft import PeftModel

# Load the LoRA weights
transformer = Transformer2DModel.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS",
    subfolder="transformer",
    torch_dtype=torch.float16
)
transformer = PeftModel.from_pretrained(
    transformer,
    "jasperai/flash-pixart"
)

# Pipeline
pipe = PixArtAlphaPipeline.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS",
    transformer=transformer,
    torch_dtype=torch.float16
)

# Scheduler
pipe.scheduler = LCMScheduler.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS",
    subfolder="scheduler",
    timestep_spacing="trailing",
)

pipe.to("cuda")

prompt = "A raccoon reading a book in a lush forest."
image = pipe(prompt, num_inference_steps=4, guidance_scale=0).images[0]
```

## License
This code is released under the [Creative Commons BY-NC 4.0 license](https://creativecommons.org/licenses/by-nc/4.0/legalcode.en).

## Citation
If you find this work useful or use it in your research, please consider citing us:

```bibtex
@misc{chadebec2024flash,
title={Flash Diffusion: Accelerating Any Conditional Diffusion Model for Few Steps Image Generation},
author={Clement Chadebec and Onur Tasar and Eyal Benaroche and Benjamin Aubin},
year={2024},
eprint={2406.02347},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```