# Enhancing Diffusion Models with Text-Encoder Reinforcement Learning
Official PyTorch code for the paper [Enhancing Diffusion Models with Text-Encoder Reinforcement Learning](https://arxiv.org/abs/2311.15657) (ECCV 2024).
[![arXiv](https://img.shields.io/badge/arXiv-Paper-.svg)](https://arxiv.org/abs/2311.15657)
[![huggingface](https://img.shields.io/badge/HuggingFace-model-red.svg)](https://huggingface.co/chaofengc/sd-turbo_texforce)
![visitors](https://visitor-badge.laobi.icu/badge?page_id=chaofengc/TexForce)

![teaser_img](./assets/fig_teaser.jpg)
## Requirements & Installation
- Clone the repo and install the required packages with:
```
# clone this repository
git clone https://github.com/chaofengc/TexForce.git
cd TexForce

# create new anaconda env
conda create -n texforce python=3.8
source activate texforce

# install python dependencies
pip3 install -r requirements.txt
```
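To confirm the environment is set up correctly, a quick sanity check (a convenience, not part of the original instructions):

```
python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
```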
## Results on SDXL-Turbo

We also applied our method to the recent model [sdxl-turbo](https://huggingface.co/stabilityai/sdxl-turbo). The model is trained with [ImageReward](https://github.com/THUDM/ImageReward) feedback through direct back-propagation to save training time. Test it with the following code:
```
## Note: sdxl-turbo requires the latest diffusers installed from source with the following commands
git clone https://github.com/huggingface/diffusers
cd diffusers
pip install -e .
```
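Installing diffusers from source typically yields a `.dev0` version suffix, which you can use to confirm the editable install took effect (a convenience check, not from the original README):

```
python -c "import diffusers; print(diffusers.__version__)"
```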
```
from diffusers import AutoPipelineForText2Image
import torch

pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16")
pipe = pipe.to("cuda")
pipe.load_lora_weights('chaofengc/sdxl-turbo_texforce')

pt = ['a photo of a cat.']
img = pipe(prompt=pt, num_inference_steps=1, guidance_scale=0.0).images[0]
```
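For a quick side-by-side check, you can sample the same prompt before and after loading the TexForce LoRA weights; a minimal sketch (the output filenames are illustrative):

```
import torch
from diffusers import AutoPipelineForText2Image

pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16").to("cuda")

prompt = 'a photo of a cat.'
# baseline sdxl-turbo sample
pipe(prompt=prompt, num_inference_steps=1, guidance_scale=0.0).images[0].save('cat_baseline.png')

# sample again with the TexForce text-encoder LoRA applied
pipe.load_lora_weights('chaofengc/sdxl-turbo_texforce')
pipe(prompt=prompt, num_inference_steps=1, guidance_scale=0.0).images[0].save('cat_texforce.png')
```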
Here are some example results (sdxl-turbo vs. sdxl-turbo + TexForce) for the prompts *A photo of a cat.*, *An astronaut riding a horse.*, and *water bottle.*
## Results on SD-Turbo
We applied our method to the recent model [sd-turbo](https://huggingface.co/stabilityai/sd-turbo). The model is trained with [Q-Instruct](https://github.com/Q-Future/Q-Instruct) feedback through direct back-propagation to save training time. Test it with the following code:
```
## Note: sd-turbo requires diffusers>=0.24.0 with the AutoPipelineForText2Image class
from diffusers import AutoPipelineForText2Image
from peft import PeftModel
import torch

pipe = AutoPipelineForText2Image.from_pretrained("stabilityai/sd-turbo", torch_dtype=torch.float16, variant="fp16")
pipe = pipe.to("cuda")
PeftModel.from_pretrained(pipe.text_encoder, 'chaofengc/sd-turbo_texforce')

pt = ['a photo of a cat.']
img = pipe(prompt=pt, num_inference_steps=1, guidance_scale=0.0).images[0]
```
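Note that `PeftModel.from_pretrained` injects the LoRA adapter into `pipe.text_encoder` in place, which is why its return value is not used above. A quick way to confirm the adapter layers were attached (a convenience check, not part of the original code):

```
# count the LoRA submodules now present in the text encoder
lora_modules = [name for name, _ in pipe.text_encoder.named_modules() if 'lora' in name.lower()]
print(f'{len(lora_modules)} LoRA submodules injected')
```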
Here are some example results (sd-turbo vs. sd-turbo + TexForce) for the prompts *A photo of a cat.*, *A photo of a dog.*, and *A photo of a boy, colorful.*
## Results on SD-1.4, SD-1.5, SD-2.1
For code compatibility, you need to install the following version of diffusers first:
```
pip uninstall diffusers
pip install diffusers==0.16.0
```

You may simply load the pretrained LoRA weights with the following code block to improve the performance of the original Stable Diffusion model:
```
from diffusers import StableDiffusionPipeline
from diffusers import DDIMScheduler
from peft import PeftModel
import torch

def load_model_weights(pipe, weight_path, model_type):
    if model_type == 'text+lora':
        text_encoder = pipe.text_encoder
        PeftModel.from_pretrained(text_encoder, weight_path)
    elif model_type == 'unet+lora':
        pipe.unet.load_attn_procs(weight_path)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model_id = "CompVis/stable-diffusion-v1-4"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16)
pipe = pipe.to("cuda")
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

load_model_weights(pipe, './lora_weights/sd14_refl/', 'unet+lora')
load_model_weights(pipe, './lora_weights/sd14_texforce/', 'text+lora')

prompt = ['a painting of a dog.']
img = pipe(prompt).images[0]
```
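The same loader should work for the other base models named in this section's title; a sketch for SD-2.1, where the model id is the official stabilityai checkpoint but the LoRA folder names are assumptions by analogy with the `sd14_*` weights above:

```
# hypothetical weight folders, assumed by analogy with sd14_refl / sd14_texforce
model_id = "stabilityai/stable-diffusion-2-1"
pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
load_model_weights(pipe, './lora_weights/sd21_refl/', 'unet+lora')
load_model_weights(pipe, './lora_weights/sd21_texforce/', 'text+lora')
img = pipe(['a painting of a dog.']).images[0]
```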
Here are some example results (SDv1.4 vs. ReFL vs. TexForce vs. ReFL+TexForce) for the prompts:

- *astronaut drifting afloat in space, in the darkness away from anyone else, alone, black background dotted with stars, realistic*
- *portrait of a cute cyberpunk cat, realistic, professional*
- *a coffee mug made of cardboard*
## Training
We rewrote the training code based on [trl](https://github.com/huggingface/trl) with the latest diffusers library.
> [!NOTE]
> The latest diffusers supports simple loading of LoRA weights with `pipeline.load_lora_weights` after training.

You may train the model with the following command:
### Example script for single-prompt training
```
accelerate launch --num_processes 2 src/train_ddpo.py \
--mixed_precision="fp16" \
--sample_num_steps 50 --train_timestep_fraction 0.5 \
--num_epochs 40 \
--sample_batch_size 4 --sample_num_batches_per_epoch 64 \
--train_batch_size 4 --train_gradient_accumulation_steps 1 \
--prompt="single" --single_prompt_type="hand" --reward_list="handdetreward" \
--per_prompt_stat_tracking=True \
--tracker_project_name="texforce_hand" \
--log_with="tensorboard"
```
The supported prompts and reward functions are listed below; a sketch of the reward interface follows the list:
- prompts: `hand`, `face`, `color`, `count`, `comp`, `location`
- rewards: `handdetreward`, `topiq_nr-face`, `imagereward`
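Under the hood, the training loop follows trl's DDPO setup, where each reward is a callable over the generated images and their prompts. A minimal sketch of such a reward function, following the convention in trl's DDPO examples (the exact signatures in this repo may differ):

```
import torch

def toy_brightness_reward(images, prompts, prompt_metadata):
    # images: batch of decoded sample images as float tensors
    # returns one scalar reward per image plus a metadata dict,
    # matching the reward-callable convention in trl's DDPO examples
    rewards = images.float().mean(dim=(1, 2, 3))
    return rewards, {}
```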
### Example script for complex multi-prompt training

```
accelerate launch --num_processes 2 src/train_ddpo.py \
--mixed_precision="fp16" \
--sample_num_steps 50 --train_timestep_fraction 0.5 \
--num_epochs 50 \
--sample_batch_size 4 --sample_num_batches_per_epoch 128 \
--train_batch_size 4 --train_gradient_accumulation_steps 4 \
--prompt="imagereward" --reward_list="imagereward" \
--per_prompt_stat_tracking=True \
--tracker_project_name="texforce_imgreward" \
--log_with="tensorboard"
```
The supported prompts and reward functions are:
- prompts: `imagereward`, `hps`
- rewards: `imagereward`, `hpsreward`, `laion_aes`
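Since both example runs pass `--log_with="tensorboard"`, training curves can be monitored with TensorBoard; the log directory below is a placeholder for wherever your run writes its logs:

```
tensorboard --logdir <run_log_dir>
```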
## Citation

If you find this code useful for your research, please cite our paper:
```
@inproceedings{chen2024texforce,
title={Enhancing Diffusion Models with Text-Encoder Reinforcement Learning},
author={Chaofeng Chen and Annan Wang and Haoning Wu and Liang Liao and Wenxiu Sun and Qiong Yan and Weisi Lin},
year={2024},
booktitle = {Proceedings of the European Conference on Computer Vision (ECCV)},
}
```

## License
This work is licensed under [NTU S-Lab License 1.0](./LICENCE_S-Lab) and a Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License.
## Acknowledgement
This project is largely based on [trl](https://github.com/huggingface/trl). The hand detection codes are taken from [Unified-Gesture-and-Fingertip-Detection](https://github.com/MahmudulAlam/Unified-Gesture-and-Fingertip-Detection). Many thanks to their great work :hugs:!