https://github.com/XavierXiao/Dreambooth-Stable-Diffusion
Implementation of Dreambooth (https://arxiv.org/abs/2208.12242) with Stable Diffusion
pytorch pytorch-lightning stable-diffusion text-to-image
- Host: GitHub
- URL: https://github.com/XavierXiao/Dreambooth-Stable-Diffusion
- Owner: XavierXiao
- License: mit
- Created: 2022-09-06T06:57:40.000Z (over 3 years ago)
- Default Branch: main
- Last Pushed: 2022-12-08T02:19:03.000Z (about 3 years ago)
- Last Synced: 2024-10-29T15:33:25.448Z (over 1 year ago)
- Topics: pytorch, pytorch-lightning, stable-diffusion, text-to-image
- Language: Jupyter Notebook
- Homepage:
- Size: 5.71 MB
- Stars: 7,595
- Watchers: 92
- Forks: 795
- Open Issues: 137
Metadata Files:
- Readme: README.md
- License: LICENSE
Awesome Lists containing this project
- awesome-diffusion-categorized - Unofficial Code
- awesome-stable-diffusion - Dreambooth-Stable-Diffusion - Implementation of [Google's DreamBooth](https://arxiv.org/abs/2208.12242) for stable diffusion, allowing fine-tuning of the model for specific concepts. (Training / Potentially Stale/Less active branches)
- awesome-conditional-content-generation - Code
- awesome-stable-diffusion - Dreambooth-Stable-Diffusion - Dreambooth fine-tuning for Stable Diffusion. (👑Stable Diffusion / Python)
- StarryDivineSky - XavierXiao/Dreambooth-Stable-Diffusion - An open-source project that implements the Dreambooth technique on top of the Stable Diffusion model, aiming to generate high-quality images of a specific object or character from a small number of training samples. Its core idea is to fine-tune Stable Diffusion's diffusion architecture with the Dreambooth method so the model learns the target object's features, then combine them with text prompts to generate matching images. The user only needs to supply a few photos of the target object (typically 5-10) and corresponding text descriptions; training embeds those features into the diffusion process, after which the model can generate new images containing the object. The project supports several training modes, including fine-tuning from pre-trained Stable Diffusion weights or training from scratch. The workflow has three stages: prepare the target object's image dataset, adjust the model parameters with the Dreambooth method, and generate images with the resulting model. The developer provides a Colab notebook that simplifies training and lets users run the code in the cloud, along with optimization tips such as using a VAE (variational autoencoder) to improve image quality or adjusting training parameters to shorten training time. The technical highlight is an efficient adaptation of Stable Diffusion that enables customized generation from a simple dataset and text instructions; it relies on the reverse diffusion process, generating images by progressive denoising while text embeddings of the target object steer the generation during training. The documentation covers data preparation, training-parameter settings, and tips for improving results, and is aimed at developers with some machine-learning background. Because Stable Diffusion is compute-intensive, a GPU environment is recommended, and training may take anywhere from several hours to tens of hours. Overall, the project offers a convenient path to customized Stable Diffusion applications such as image generation and character design. (Image generation / resource downloads)
README
# Dreambooth on Stable Diffusion
This is an implementation of Google's [Dreambooth](https://arxiv.org/abs/2208.12242) with [Stable Diffusion](https://github.com/CompVis/stable-diffusion). The original Dreambooth is based on the [Imagen](https://imagen.research.google/) text-to-image model; however, neither the model nor the pre-trained weights of Imagen are available. To enable people to fine-tune a text-to-image model with a few examples, I implemented the idea of Dreambooth on Stable Diffusion.
This code repository is based on that of [Textual Inversion](https://github.com/rinongal/textual_inversion). Note that Textual Inversion only optimizes the word embedding, while Dreambooth fine-tunes the whole diffusion model.
The implementation makes minimal changes to the official Textual Inversion codebase. In fact, out of laziness, some components of Textual Inversion, such as the embedding manager, are kept even though they are never used here.
## Update
**9/20/2022**: I just found a way to reduce GPU memory usage a bit. Remember that this code is based on Textual Inversion, and TI's codebase has [this line](https://github.com/rinongal/textual_inversion/blob/main/ldm/modules/diffusionmodules/util.py#L112), which disables gradient checkpointing in a hard-coded way. That is because in TI the UNet is not optimized. In Dreambooth, however, we optimize the UNet, so we can turn the gradient checkpointing trick back on, as in the original SD repo [here](https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L112). Gradient checkpointing now defaults to True in the [config](https://github.com/XavierXiao/Dreambooth-Stable-Diffusion/blob/main/configs/stable-diffusion/v1-finetune_unfrozen.yaml#L47). I have updated the code.
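For intuition, here is a minimal sketch of the flag-gated checkpoint wrapper that the config toggles; it uses ```torch.utils.checkpoint``` in place of the repo's custom ```CheckpointFunction```, so treat it as an illustration rather than the exact code:
```
import torch
from torch.utils.checkpoint import checkpoint as torch_checkpoint

def checkpoint(func, inputs, params, flag):
    # flag=True (the unfrozen Dreambooth default): do not store intermediate
    # activations; recompute them during backward, trading compute for memory.
    # flag=False (the hard-coded Textual Inversion behaviour): run normally.
    # `params` is kept only to mirror the original signature.
    if flag:
        return torch_checkpoint(func, *inputs)
    return func(*inputs)
```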
## Usage
### Preparation
First, set up the ```ldm``` environment following the instructions in the Textual Inversion repo or the original Stable Diffusion repo.
To fine-tune a Stable Diffusion model, you need to obtain the pre-trained weights following their [instructions](https://github.com/CompVis/stable-diffusion#stable-diffusion-v1). Weights can be downloaded from [HuggingFace](https://huggingface.co/CompVis). You can decide which checkpoint version to use; I use ```sd-v1-4-full-ema.ckpt```.
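If you prefer to script the download, a sketch like the one below should work with the ```huggingface_hub``` client, assuming you have accepted the model license on HuggingFace and are logged in; the repo id and filename are the ones CompVis publishes for SD v1.4, so adjust them if you pick a different checkpoint:
```
from huggingface_hub import hf_hub_download

# Downloads the full-EMA v1.4 checkpoint into the local HF cache and returns its path.
# The repo is gated, so run `huggingface-cli login` (or pass a token) first.
ckpt_path = hf_hub_download(
    repo_id="CompVis/stable-diffusion-v-1-4-original",
    filename="sd-v1-4-full-ema.ckpt",
)
print(ckpt_path)
```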
We also need to create a set of images for regularization, because the fine-tuning algorithm of Dreambooth requires it. Details of the algorithm can be found in the paper. Note that in the original paper the regularization images seem to be generated on the fly, whereas here I generate a fixed set of regularization images before training. The text prompt for generating regularization images can be ```photo of a <class>```, where ```<class>``` is a word that describes the class of your object, such as ```dog```. The command is
```
python scripts/stable_txt2img.py --ddim_eta 0.0 --n_samples 8 --n_iter 1 --scale 10.0 --ddim_steps 50 --ckpt /path/to/original/stable-diffusion/sd-v1-4-full-ema.ckpt --prompt "a photo of a <class>"
```
I generate 8 images for regularization, but using more regularization images may lead to stronger regularization and better editability. After that, save the generated images (separately, one image per ```.png``` file) at ```/root/to/regularization/images```.
**Updates on 9/9**
We should definitely use more images for regularization. Please try 100 or 200 to better align with the original paper. To accommodate this, I shortened the "repeat" of the reg dataset in the [config file](https://github.com/XavierXiao/Dreambooth-Stable-Diffusion/blob/main/configs/stable-diffusion/v1-finetune_unfrozen.yaml#L96).
In some cases the generated regularization images are highly unrealistic (this happens when you want to generate "man" or "woman"); if so, you can find a diverse set of images (of man/woman) online and use them as regularization images.
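However you obtain them, the regularization images just need to end up one per ```.png``` in the folder passed via ```--reg_data_root```, ideally at the 512x512 training resolution. A minimal sketch for preparing found images (both paths below are placeholders):
```
import os
from PIL import Image

src_dir = "/path/to/downloaded/images"      # placeholder: your collected class images
dst_dir = "/root/to/regularization/images"  # folder later passed via --reg_data_root

os.makedirs(dst_dir, exist_ok=True)
for i, name in enumerate(sorted(os.listdir(src_dir))):
    img = Image.open(os.path.join(src_dir, name)).convert("RGB")
    # Center-crop to a square, then resize to 512x512 (SD v1 training resolution).
    side = min(img.size)
    left = (img.width - side) // 2
    top = (img.height - side) // 2
    img = img.crop((left, top, left + side, top + side)).resize((512, 512), Image.LANCZOS)
    img.save(os.path.join(dst_dir, f"reg_{i:04d}.png"))
```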
### Training
Training can be done by running the following command
```
python main.py --base configs/stable-diffusion/v1-finetune_unfrozen.yaml
-t
--actual_resume /path/to/original/stable-diffusion/sd-v1-4-full-ema.ckpt
-n <job name>
--gpus 0,
--data_root /root/to/training/images
--reg_data_root /root/to/regularization/images
--class_word <class>
```
Detailed configuration can be found in ```configs/stable-diffusion/v1-finetune_unfrozen.yaml```. In particular, the default learning rate is ```1.0e-6```, as I found that the ```1.0e-5``` used in the Dreambooth paper leads to poor editability. The parameter ```reg_weight``` corresponds to the weight of regularization in the Dreambooth paper, and its default is ```1.0```.
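If you want to inspect these values without hand-editing the YAML, you can load it with OmegaConf (already part of the ```ldm``` environment). The key paths below are my assumption about where an LDM-style config keeps them, so verify them against your copy of the file:
```
from omegaconf import OmegaConf

cfg = OmegaConf.load("configs/stable-diffusion/v1-finetune_unfrozen.yaml")

# Assumed key paths for an LDM-style config; confirm them in the YAML itself.
print("learning rate:", cfg.model.base_learning_rate)  # expected 1.0e-6
print("reg_weight:", cfg.model.params.reg_weight)      # expected 1.0
```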
Dreambooth requires a placeholder word ```[V]```, called the identifier, as in the paper. This identifier needs to be a relatively rare token in the vocabulary. The original paper approaches this by using a rare word from the T5-XXL tokenizer. For simplicity, here I just use a random word, ```sks```, and hard-coded it. If you want to change that, simply edit [this file](https://github.com/XavierXiao/Dreambooth-Stable-Diffusion/blob/main/ldm/data/personalized.py#L10).
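For orientation, the hard-coded prompt templates in ```ldm/data/personalized.py``` look roughly like the sketch below (paraphrased, not a verbatim copy of the file); swapping ```sks``` for another rare token is the only change needed:
```
# ldm/data/personalized.py (roughly): the identifier is baked into the prompt template.
training_templates_smallest = [
    'photo of a sks {}',  # '{}' is filled in with the --class_word at training time
]

# To use a different identifier, replace 'sks' here, e.g.:
# training_templates_smallest = ['photo of a zwx {}']
```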
Training runs for 800 steps, and two checkpoints are saved at ```./logs/<job_name>/checkpoints```, one at 500 steps and one at the final step. Typically the one at 500 steps works well enough. I train the model on two A6000 GPUs, and it takes ~15 minutes.
### Generation
After training, personalized samples can be obtained by running the command
```
python scripts/stable_txt2img.py --ddim_eta 0.0
--n_samples 8
--n_iter 1
--scale 10.0
--ddim_steps 100
--ckpt /path/to/saved/checkpoint/from/training
--prompt "photo of a sks "
```
In particular, ```sks``` is the identifier; replace it with your own choice if you changed the identifier. ```<class>``` is the class word you passed via ```--class_word``` during training. For example, with class word ```dog```, the prompt would be ```photo of a sks dog```.
## Results
Here I show some qualitative results. The training images are taken from [this issue](https://github.com/rinongal/textual_inversion/issues/8) in the Textual Inversion repository: 3 images of a large trash container. Regularization images were generated with the prompt ```photo of a container``` and are shown here:

After training, generated images with prompt ```photo of a sks container```:

Generated images with prompt ```photo of a sks container on the beach```:

Generated images with prompt ```photo of a sks container on the moon```:

Some not-so-perfect but still interesting results:
Generated images with prompt ```photo of a red sks container```:

Generated images with prompt ```a dog on top of sks container```:
