https://github.com/carperai/drlx
Diffusion Reinforcement Learning Library
- Host: GitHub
- URL: https://github.com/carperai/drlx
- Owner: CarperAI
- License: mit
- Created: 2023-05-31T22:53:31.000Z (about 3 years ago)
- Default Branch: main
- Last Pushed: 2024-02-13T04:14:06.000Z (over 2 years ago)
- Last Synced: 2025-03-31T11:08:16.121Z (about 1 year ago)
- Language: Python
- Homepage:
- Size: 4.46 MB
- Stars: 184
- Watchers: 10
- Forks: 8
- Open Issues: 9
Metadata Files:
- Readme: README.md
- License: LICENSE
README
# Diffusion Reinforcement Learning X
DRLX is a library for distributed training of diffusion models via RL. It is meant to wrap around 🤗 Hugging Face's [Diffusers](https://huggingface.co/docs/diffusers/) library and uses [Accelerate](https://huggingface.co/docs/accelerate/) for multi-GPU and multi-node training (multi-node is as yet untested).
**News (09/27/2023): Check out our blog post with some recent experiments [here](https://carper.ai/enhancing-diffusion-models-with-reinforcement-learning/)!**
📖 **[Documentation](https://DRLX.readthedocs.io)**
# Setup
First, make sure you've installed OpenAI's [CLIP](https://github.com/openai/CLIP.git). Afterwards, you can install the library from PyPI:
```sh
pip install drlx
```
or from source:
```sh
pip install git+https://github.com/CarperAI/DRLX.git
```
# How to use
Currently we have only tested the library with Stable Diffusion 1.4, 1.5, and 2.1, but its plug-and-play design means that, in practice, the denoiser from most pipelines should be usable. Models saved with DRLX are compatible with the pipeline they originated from and can be loaded like any other pretrained model. The only training algorithm supported so far is [DDPO](https://arxiv.org/abs/2305.13301).
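DDPO treats denoising as a multi-step decision problem: with a stochastic sampler, each step from `x_t` to `x_{t-1}` is an action drawn from a Gaussian policy, and the model is updated with a PPO-style clipped, importance-sampled objective weighted by reward-derived advantages. The sketch below shows those two pieces in plain Python on toy lists of floats. It illustrates the technique from the DDPO paper under these simplifying assumptions; it is not DRLX's implementation, and the function names and the `clip_range` value are hypothetical.

```python
import math
from typing import Sequence


def gaussian_log_prob(x: Sequence[float], mean: Sequence[float], std: float) -> float:
    """Log-density of x under an isotropic Gaussian N(mean, std^2 I).

    In DDPO, one denoising step x_t -> x_{t-1} is treated as an action sampled
    from such a Gaussian, so this plays the role of the per-step policy log-prob.
    """
    d = len(x)
    sq_dist = sum((xi - mi) ** 2 for xi, mi in zip(x, mean))
    return -0.5 * sq_dist / std**2 - d * math.log(std) - 0.5 * d * math.log(2 * math.pi)


def clipped_surrogate_loss(
    new_logps: Sequence[float],
    old_logps: Sequence[float],
    advantages: Sequence[float],
    clip_range: float,
) -> float:
    """PPO-style clipped importance-sampling loss, averaged over denoising steps.

    new_logps: log-probs of the sampled steps under the current model.
    old_logps: log-probs recorded when the images were sampled.
    advantages: one per step; every step of an image shares that image's advantage.
    """
    total = 0.0
    for new, old, adv in zip(new_logps, old_logps, advantages):
        ratio = math.exp(new - old)  # importance weight pi_new / pi_old
        clipped_ratio = min(max(ratio, 1.0 - clip_range), 1.0 + clip_range)
        # Maximizing min(ratio*A, clipped*A) == minimizing max(-ratio*A, -clipped*A)
        total += max(-adv * ratio, -adv * clipped_ratio)
    return total / len(new_logps)


# Toy usage: two denoising steps of one image whose advantage is +1.0
steps_new = [gaussian_log_prob([0.1], [0.0], 1.0), gaussian_log_prob([0.2], [0.1], 1.0)]
steps_old = [gaussian_log_prob([0.1], [0.05], 1.0), gaussian_log_prob([0.2], [0.15], 1.0)]
print(clipped_surrogate_loss(steps_new, steps_old, [1.0, 1.0], clip_range=0.2))
```

The clipping keeps each update close to the policy that generated the samples, which is what allows several optimization steps per batch of sampled images.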
```python
# Import a reward model, a prompt pipeline, the trainer, and its config
from drlx.reward_modelling.aesthetics import Aesthetics
from drlx.pipeline.pickapic_prompts import PickAPicPrompts
from drlx.trainer.ddpo_trainer import DDPOTrainer
from drlx.configs import DRLXConfig

# Prompts are drawn from the Pick-a-Pic dataset
pipe = PickAPicPrompts()
# Training hyperparameters are read from a YAML config (this path is an example)
config = DRLXConfig.load_yaml("configs/my_cfg.yml")
trainer = DDPOTrainer(config)
# Fine-tune with DDPO, scoring generated images with the aesthetics reward model
trainer.train(pipe, Aesthetics())
```
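`trainer.train` receives a prompt pipeline and a reward model that scores each generated image. In DDPO-style training those raw scores are usually turned into advantages before the policy update; one common choice is to standardize each reward using the mean and standard deviation of rewards for the same prompt, so that easy and hard prompts contribute comparably. The helper below is a hypothetical, minimal sketch of that idea in plain Python. It is not part of the DRLX API, and DRLX's own normalization may differ.

```python
import statistics
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple


def per_prompt_advantages(
    prompts: Sequence[str], rewards: Sequence[float], eps: float = 1e-6
) -> List[float]:
    """Standardize each reward using the statistics of its own prompt's rewards.

    eps avoids division by zero when every sample for a prompt got the same reward.
    """
    grouped: Dict[str, List[float]] = defaultdict(list)
    for prompt, reward in zip(prompts, rewards):
        grouped[prompt].append(reward)

    # (mean, population std) per prompt
    stats: Dict[str, Tuple[float, float]] = {
        prompt: (statistics.fmean(rs), statistics.pstdev(rs))
        for prompt, rs in grouped.items()
    }
    return [
        (reward - stats[prompt][0]) / (stats[prompt][1] + eps)
        for prompt, reward in zip(prompts, rewards)
    ]


# Toy usage: two images per prompt, scored by some reward model
prompts = ["a cat", "a cat", "a dog", "a dog"]
rewards = [5.0, 7.0, 2.0, 2.0]
print(per_prompt_advantages(prompts, rewards))
```

A prompt whose samples all receive the same score yields zero advantages, so it contributes no gradient signal in that batch.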
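To use a trained model for inference, load it with the pipeline class it originated from:
```python
from diffusers import StableDiffusionPipeline

# Load the fine-tuned model saved by DRLX (the directory depends on your config)
pipe = StableDiffusionPipeline.from_pretrained("out/ddpo_exp")
prompt = "A mad panda scientist"
image = pipe(prompt).images[0]
image.save("test.jpeg")
```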
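# Accelerated Training
To train on multiple GPUs, configure Accelerate once and then launch your training module through it:
```bash
# Answer the interactive prompts (number of machines, GPUs, mixed precision, ...)
accelerate config
# Replace [your module] with the Python module that runs your training code
accelerate launch -m [your module]
```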
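With several GPUs configured, Accelerate typically wraps the model for data-parallel training: each process computes gradients on its own shard of the batch, and the gradients are averaged across processes during the backward pass. The toy, single-process simulation below (plain Python, a one-parameter least-squares model, hypothetical helper names) shows why this matches the full-batch gradient when shards are equal-sized and the loss is a mean. It illustrates the general idea only, not how DRLX or Accelerate implement it internally.

```python
from typing import List, Sequence


def mse_grad(w: float, xs: Sequence[float], ys: Sequence[float]) -> float:
    """d/dw of mean((w * x - y)^2) over the given samples."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


def data_parallel_grad(
    w: float, xs: Sequence[float], ys: Sequence[float], num_processes: int
) -> float:
    """Simulate per-process gradients on equal shards, then an all-reduce mean."""
    if len(xs) % num_processes != 0:
        raise ValueError("batch size must be divisible by num_processes")
    shard = len(xs) // num_processes
    local_grads: List[float] = [
        mse_grad(w, xs[r * shard:(r + 1) * shard], ys[r * shard:(r + 1) * shard])
        for r in range(num_processes)
    ]
    return sum(local_grads) / num_processes  # what an all-reduce(mean) produces


xs, ys = [1.0, 2.0, 3.0, 4.0], [2.0, 4.0, 6.0, 8.0]
print(mse_grad(0.5, xs, ys), data_parallel_grad(0.5, xs, ys, num_processes=2))  # -22.5 -22.5
```

If shard sizes differed, a plain average of per-shard means would weight samples unequally, which is one reason distributed samplers keep per-process batches the same size.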
# Roadmap
- [x] Initial launch and DDPO
- [x] PickScore Tuned Models
- [ ] DPO
- [ ] SDXL support