https://github.com/SalesforceAIResearch/DiffusionDPO
Code for "Diffusion Model Alignment Using Direct Preference Optimization"
- Host: GitHub
- URL: https://github.com/SalesforceAIResearch/DiffusionDPO
- Owner: SalesforceAIResearch
- License: apache-2.0
- Created: 2023-12-27T22:52:26.000Z (about 1 year ago)
- Default Branch: main
- Last Pushed: 2023-12-28T01:07:46.000Z (about 1 year ago)
- Last Synced: 2024-08-01T13:28:29.585Z (7 months ago)
- Language: Python
- Homepage: https://arxiv.org/abs/2311.12908
- Size: 4.57 MB
- Stars: 200
- Watchers: 7
- Forks: 18
- Open Issues: 10
Metadata Files:
- Readme: README.md
- Contributing: CONTRIBUTING.md
- License: LICENSE.txt
- Code of conduct: CODE_OF_CONDUCT.md
- Codeowners: CODEOWNERS
- Security: SECURITY.md
# Intro
This is the training code for [Diffusion-DPO](https://arxiv.org/abs/2311.12908). The script is adapted from [the diffusers library](https://github.com/huggingface/diffusers/tree/v0.20.0-release/examples/text_to_image).
# Model Checkpoints
The checkpoints below are initialized from StableDiffusion models and trained as described in the paper (replicable with the [launchers/](launchers/) scripts, assuming 16 GPUs; scale gradient accumulation accordingly).
[StableDiffusion1.5](https://huggingface.co/mhdang/dpo-sd1.5-text2image-v1)
[StableDiffusion-XL-1.0](https://huggingface.co/mhdang/dpo-sdxl-text2image-v1?text=Test)
Use this [notebook](quick_samples.ipynb) to compare generations against the baseline. It also includes a sample of automatic quantitative evaluation using PickScore.
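For a quick smoke test outside the notebook, the snippet below loads the DPO-tuned SD1.5 UNet into the standard pipeline. This is a minimal sketch that assumes the checkpoint repo stores a diffusers-format UNet under a `unet` subfolder (verify against the model card on the Hub):

```python
import torch
from diffusers import StableDiffusionPipeline, UNet2DConditionModel

# Assumption: the DPO checkpoint exposes a diffusers-format UNet under "unet".
unet = UNet2DConditionModel.from_pretrained(
    "mhdang/dpo-sd1.5-text2image-v1", subfolder="unet", torch_dtype=torch.float16
)
# Swap the DPO-tuned UNet into the base SD1.5 pipeline.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", unet=unet, torch_dtype=torch.float16
).to("cuda")

image = pipe("a photo of an astronaut riding a horse on the moon").images[0]
image.save("dpo_sd15_sample.png")
```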
# Setup
`pip install -r requirements.txt`
# Structure
- `launchers/` example scripts for running SD1.5 or SDXL training
- `utils/` scoring models for evaluation or AI feedback (PickScore, HPS, Aesthetics, CLIP); see the standalone sketch after this list
- `quick_samples.ipynb` visualizations from a pretrained model vs. the baseline
- `requirements.txt` basic pip requirements
- `train.py` main script; pretty bulky at >1000 lines, and the training loop starts around L1000 at this commit (`ctrl-F` "for epoch")
- `upload_model_to_hub.py` uploads a model checkpoint to HF (simple utility; current values are placeholders)
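The repo's own scoring wrappers live in `utils/`. As a standalone illustration, here is a hedged sketch of PickScore-style scoring using the public `yuvalkirstain/PickScore_v1` checkpoint; model names and preprocessing follow the public model card, not this repo's code:

```python
import torch
from transformers import AutoProcessor, AutoModel

device = "cuda" if torch.cuda.is_available() else "cpu"
# CLIP-H processor paired with the PickScore reward model.
processor = AutoProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
model = AutoModel.from_pretrained("yuvalkirstain/PickScore_v1").eval().to(device)

@torch.no_grad()
def pickscore(prompt, images):
    """Return one preference score per image for the given prompt."""
    image_inputs = processor(images=images, return_tensors="pt").to(device)
    text_inputs = processor(text=prompt, padding=True, truncation=True,
                            max_length=77, return_tensors="pt").to(device)
    img = model.get_image_features(**image_inputs)
    img = img / img.norm(dim=-1, keepdim=True)
    txt = model.get_text_features(**text_inputs)
    txt = txt / txt.norm(dim=-1, keepdim=True)
    # Scaled cosine similarity between the prompt and each image.
    return (model.logit_scale.exp() * txt @ img.T).squeeze(0)
```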
# Running the training

Example SD1.5 launch:
```bash
# from launchers/sd15.sh
export MODEL_NAME="runwayml/stable-diffusion-v1-5"
export DATASET_NAME="yuvalkirstain/pickapic_v2"

# Effective BS will be (N_GPU * train_batch_size * gradient_accumulation_steps)
# Paper used 2048. Training takes ~24 hours / 2000 steps
accelerate launch --mixed_precision="fp16" train.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--dataset_name=$DATASET_NAME \
--train_batch_size=1 \
--dataloader_num_workers=16 \
--gradient_accumulation_steps=1 \
--max_train_steps=2000 \
--lr_scheduler="constant_with_warmup" --lr_warmup_steps=500 \
--learning_rate=1e-8 --scale_lr \
--cache_dir="/export/share/datasets/vision_language/pick_a_pic_v2/" \
--checkpointing_steps 500 \
--beta_dpo 5000 \
--output_dir="tmp-sd15"
```
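The launcher above trains on Pick-a-Pic v2. As a hedged sketch (column names such as `caption` and `label_0` are taken from the public dataset card, so verify them against the current schema), you can stream a few preference pairs to see what the trainer consumes:

```python
from datasets import load_dataset

# Stream to avoid downloading the full (very large) dataset.
ds = load_dataset("yuvalkirstain/pickapic_v2", split="train", streaming=True)
for i, ex in enumerate(ds):
    # Assumption from the dataset card: label_0 == 1 means image 0 was preferred.
    preferred = 0 if ex["label_0"] == 1 else 1
    print(f"{ex['caption']!r} -> preferred image: {preferred}")
    if i == 2:
        break
```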
## Important Args

### General
- `--pretrained_model_name_or_path` which model to train/initialize from
- `--output_dir` where to save/log to
- `--seed` training seed (not set by default)
- `--sdxl` run SDXL training
- `--sft` run SFT instead of DPO

### DPO
- `--beta_dpo` KL-divergence parameter beta for DPO (see the loss sketch after this arg list)
- `--choice_model` model for AI feedback (Aesthetics, CLIP, PickScore, HPS)

### Optimizers/learning rates
- `--max_train_steps` How many train steps to take
- `--gradient_accumulation_steps`
- `--train_batch_size` per-GPU batch size; see the note in the launch script above for the effective batch size
- `--checkpointing_steps` how often to save model
- `--gradient_checkpointing` turned on automatically for SDXL
- `--learning_rate`
- `--scale_lr` we found this very helpful, but it isn't the default in the code
- `--lr_scheduler` Type of LR warmup/decay. Default is linear warmup to constant
- `--lr_warmup_steps` number of scheduler warmup steps
- `--use_adafactor` Adafactor instead of Adam (lower memory; default for SDXL)

### Data
- `--dataset_name` if you want to switch from Pick-a-Pic
- `--cache_dir` where dataset is cached locally **(users will want to change this to fit their file system)**
- `--resolution` defaults to 512 for non-SDXL, 1024 for SDXL.
- `--random_crop` and `--no_hflip` changes data aug
- `--dataloader_num_workers` total number of dataloader workers
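For intuition on what `--beta_dpo` controls, here is a hedged sketch of the core Diffusion-DPO objective from the paper (`train.py` implements the exact version). The inputs are per-sample denoising MSE losses for the trained model and the frozen reference model on the preferred (`w`) and rejected (`l`) images:

```python
import torch
import torch.nn.functional as F

def diffusion_dpo_loss(model_w, model_l, ref_w, ref_l, beta_dpo=5000.0):
    """model_*/ref_*: per-sample denoising MSE losses, shape (batch,)."""
    # How much the trained model favors the preferred image over the rejected
    # one, relative to the frozen reference model.
    model_diff = model_w - model_l
    ref_diff = ref_w - ref_l
    # beta_dpo scales the margin: larger beta keeps the model closer to the
    # reference (stronger implicit KL penalty).
    inside = -0.5 * beta_dpo * (model_diff - ref_diff)
    # Minimize the negative log-sigmoid of the preference margin.
    return -F.logsigmoid(inside).mean()
```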
# Citation

```
@misc{wallace2023diffusion,
title={Diffusion Model Alignment Using Direct Preference Optimization},
author={Bram Wallace and Meihua Dang and Rafael Rafailov and Linqi Zhou and Aaron Lou and Senthil Purushwalkam and Stefano Ermon and Caiming Xiong and Shafiq Joty and Nikhil Naik},
year={2023},
eprint={2311.12908},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```