https://github.com/ZiyiZhang27/tdpo
[ICML 2024] Code for the paper "Confronting Reward Overoptimization for Diffusion Models: A Perspective of Inductive and Primacy Biases"
- Host: GitHub
- URL: https://github.com/ZiyiZhang27/tdpo
- Owner: ZiyiZhang27
- License: MIT
- Created: 2024-05-19T06:29:40.000Z
- Default Branch: main
- Last Pushed: 2024-05-20T15:11:00.000Z
- Last Synced: 2024-05-21T13:44:15.120Z
- Topics: alignment, diffusion-models, human-feedback, reinforcement-learning, rlhf, stable-diffusion, text-to-image
- Language: Python
- Homepage: https://arxiv.org/abs/2402.08552
- Size: 3.28 MB
- Stars: 3
- Watchers: 2
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
Awesome Lists containing this project
- awesome-RLHF - official
- StarryDivineSky - ZiyiZhang27/tdpo - Temporal Diffusion Policy Optimization (TDPO-R), a policy gradient algorithm that exploits the temporal inductive bias of diffusion models and mitigates the primacy bias stemming from active neurons. Empirical results demonstrate the superior efficacy of our method in mitigating reward overoptimization. (A01_Text Generation_Text Dialogue / Large Language Dialogue Models and Data)
README
# Temporal Diffusion Policy Optimization (TDPO)
This is an official PyTorch implementation of **Temporal Diffusion Policy Optimization (TDPO)** from our paper [*Confronting Reward Overoptimization for Diffusion Models: A Perspective of Inductive and Primacy Biases*](https://openreview.net/pdf?id=v2o9rRJcEv), which was accepted at **ICML 2024**.
## Installation
Python 3.10 or newer is required. To install the requirements, create a conda environment and install this repository via its `setup.py`, e.g. by running the following commands:

```bash
conda create -n tdpo python=3.10.12 -y
conda activate tdpo
git clone git@github.com:ZiyiZhang27/tdpo.git
cd tdpo
pip install -e .
```

## Training
To train on **Aesthetic Score** and evaluate *cross-reward generalization* with out-of-domain reward functions, run this command:
```bash
accelerate launch scripts/train_tdpo.py --config config/config_tdpo.py:aesthetic
```
To train on **PickScore** and evaluate *cross-reward generalization* with out-of-domain reward functions, run this command:

```bash
accelerate launch scripts/train_tdpo.py --config config/config_tdpo.py:pickscore
```

To train on **HPSv2** and evaluate *cross-reward generalization* with out-of-domain reward functions, run this command:
```bash
accelerate launch scripts/train_tdpo.py --config config/config_tdpo.py:hpsv2
```

For detailed explanations of all hyperparameters, please refer to the configuration files `config/base_tdpo.py` and `config/config_tdpo.py`. These files are pre-configured for training with 8 NVIDIA A100 GPUs (each with 40GB of memory).
**Note:** Some hyperparameters may appear in both configuration files. In such cases, only the values set in `config/config_tdpo.py` are used during training, as this file has higher priority; the sketch below illustrates this override pattern.
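For intuition, here is a minimal sketch of how such a two-level configuration typically resolves, assuming an `ml_collections`-style setup (which the `--config config/config_tdpo.py:aesthetic` syntax suggests). All field names and values below are illustrative assumptions, not the repository's actual hyperparameters:

```python
# Minimal sketch of the base/override config pattern (ml_collections style).
# Field names and values here are hypothetical, not the repository's settings.
import ml_collections


def base_config() -> ml_collections.ConfigDict:
    # Stands in for config/base_tdpo.py: shared defaults.
    config = ml_collections.ConfigDict()
    config.num_epochs = 100         # hypothetical default
    config.reward_fn = "aesthetic"  # hypothetical default
    return config


def get_config(name: str) -> ml_collections.ConfigDict:
    # Stands in for config/config_tdpo.py. With absl's config_flags, passing
    # --config config/config_tdpo.py:pickscore calls get_config("pickscore").
    config = base_config()
    if name == "pickscore":
        # Values assigned here shadow the base defaults, which is why
        # config/config_tdpo.py takes precedence during training.
        config.reward_fn = "pickscore"
        config.num_epochs = 200     # hypothetical override
    return config
```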
## Citation
If you find this work useful in your research, please consider citing:
```bibtex
@inproceedings{zhang2024confronting,
title={Confronting Reward Overoptimization for Diffusion Models: A Perspective of Inductive and Primacy Biases},
author={Ziyi Zhang and Sen Zhang and Yibing Zhan and Yong Luo and Yonggang Wen and Dacheng Tao},
booktitle={Forty-first International Conference on Machine Learning},
year={2024}
}
```

## Acknowledgement
- This repository is built upon the [PyTorch codebase of DDPO](https://github.com/kvablack/ddpo-pytorch) developed by Kevin Black and his team. We are grateful for their contribution to the field.
- We also extend our thanks to Timo Klein for open-sourcing the [PyTorch reimplementation](https://github.com/timoklein/redo/) of [ReDo](https://arxiv.org/abs/2302.12902).
- We also acknowledge the contributions of the [PickScore](https://github.com/yuvalkirstain/PickScore), [HPSv2](https://github.com/tgxs002/HPSv2), and [ImageReward](https://github.com/THUDM/ImageReward) projects to this work.