https://github.com/liziniu/ReMax
Code for the paper "ReMax: A Simple, Efficient and Effective Reinforcement Learning Method for Aligning Large Language Models"
large-language-models policy-gradient reinforcement-learning rlhf
- Host: GitHub
- URL: https://github.com/liziniu/ReMax
- Owner: liziniu
- Created: 2023-10-17T05:25:36.000Z (about 1 year ago)
- Default Branch: master
- Last Pushed: 2023-12-16T12:38:36.000Z (about 1 year ago)
- Last Synced: 2024-08-03T09:06:45.988Z (5 months ago)
- Topics: large-language-models, policy-gradient, reinforcement-learning, rlhf
- Language: Python
- Homepage:
- Size: 1.76 MB
- Stars: 135
- Watchers: 2
- Forks: 12
- Open Issues: 0
Metadata Files:
- Readme: README.md
Awesome Lists containing this project
- StarryDivineSky - liziniu/ReMax - When judged by GPT-4, ReMax achieves win rates of 84.22%, 75.28%, and 63.60% over SFT, DPO, and PPO, respectively. (A01_Text Generation_Text Dialogue / Large Language Dialogue Models and Data)
README
# ReMax: A Simple, Effective, and Efficient Method for Aligning Large Language Models
## Overview
ReMax is a reinforcement learning method tailored for reward maximization in RLHF.
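In brief, ReMax replaces PPO's learned value baseline with the reward of a greedy decode. Sketching the estimator from the paper's description: for a prompt $x$, a sampled response $y \sim \pi_\theta(\cdot \mid x)$, and the greedy response $\bar{y}$ to the same prompt, the policy gradient is estimated as

$$
\nabla_\theta J(\theta) \approx \big(r(x, y) - r(x, \bar{y})\big)\, \nabla_\theta \log \pi_\theta(y \mid x),
$$

where $r$ is the reward model. Subtracting $r(x, \bar{y})$ acts as a variance-reducing baseline that requires no value network.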
#### Simple Implementation
ReMax is easy to implement (with **6 lines of code**). We provide an implementation based on the DeepSpeed framework in this repository.
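For intuition, here is a minimal PyTorch-style sketch of the core update; the function and tensor names are illustrative, not the repository's actual API:

```python
import torch

def remax_loss(log_probs: torch.Tensor,
               sample_reward: torch.Tensor,
               greedy_reward: torch.Tensor) -> torch.Tensor:
    """Illustrative ReMax loss: REINFORCE with a greedy-decoding baseline.

    log_probs:     (batch, seq_len) log-probs of the sampled response tokens
    sample_reward: (batch,) reward of the sampled response
    greedy_reward: (batch,) reward of the greedy response to the same prompt
    """
    # The greedy response's reward serves as a per-prompt baseline,
    # so no value model has to be trained or stored.
    advantage = (sample_reward - greedy_reward).detach()
    # Standard policy-gradient loss, weighted by the baseline-corrected reward.
    return -(advantage * log_probs.sum(dim=-1)).mean()
```

Because the baseline costs only one extra reward-model query on the greedy decode, PPO's value network (with its activations and optimizer states) can be dropped entirely, which is where the memory and speed savings below come from.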
#### Memory Efficient
ReMax is memory-efficient. Compared with PPO, ReMax saves about 50% of GPU memory consumption, which can be reallocated to support a **1.3x larger batch size**.
Results of tuning Llama2-7B with A100-80GB GPUs
| GPUs | Offload | Method | Maximum Batch Size |
| ---- | ------- | ------ | ------------------ |
| 4 | False | PPO | ❌ (OOM) |
| 4 | False | ReMax | **4x26=104** |
| 4 | True | PPO | 4x30=120 |
| 4 | True | ReMax | **4x40=160** |
| 1 | True | PPO | 1x32=32 |
| 1 | True | ReMax | **1x42=42** |

*: Gradient checkpointing and ZeRO-2 are used for LLM.
*: ZeRO-3 and offload are used for the reward model and the reference model.
#### Fast Training
ReMax runs fast. It does not need to train a value model and requires less computation per step. Usually, it achieves about a **2x training speed-up**.
Results of tuning Llama2-7B with A100-80GB GPUs
| GPUs | Offload | Method | Total Training Time |
| ---- | ------- | ------ | ------------------- |
| 4 | False | PPO | ❌ (OOM) |
| 4 | False | ReMax | **2.4h** |
| 4 | True | PPO | 6.0h |
| 4 | True | ReMax | **2.8h** |
| 1 | True | PPO | 22.0h |
| 1 | True | ReMax | **10.2h** |

*: Gradient checkpointing and ZeRO-2 are used for LLM.
*: ZeRO-3 and offload are used for the reward model and the reference model.
*: Measurement is based on 45k training samples (with 1 epoch) from the full-hh-rlhf dataset.
#### Easy to Tune
ReMax is easy to tune for good performance. On the AlpacaEval benchmark, when judged by GPT-4, ReMax achieves win rates of 84.22%, 75.28%, and 63.60% over SFT, DPO, and PPO, respectively.
## Change Log
- [2023-12-16] Add response samples of trained models and evaluation results of training speed.
- [2023-10-18] Release the initial code.

## How to use
### Prepare
The Python environment can be set up using Anaconda with the provided `environment.yml` file.
```
conda env create -f environment.yml
conda activate llm
```

### Step 1 SFT
```
cd step1_supervised_finetuning

# OPT (1.3B)
bash training_scripts/opt/run_opt_1.3b.sh

# Llama2 (7B)
bash training_scripts/llama2/run_llama2_1.3b.sh
```

### Step 2 Reward Learning
```
cd step2_reward_model_finetuning

# OPT (1.3B)
bash training_scripts/opt/run_opt_1.3b.sh

# Llama2 (7B)
bash training_scripts/llama2/run_llama2_1.3b.sh
```

### Step 3 RLHF
```
cd step3_rlhf_finetuning

# OPT (1.3B)
bash training_scripts/opt/run_opt_1.3b.sh

# Llama2 (7B)
bash training_scripts/llama2/run_llama2_1.3b.sh
```

## Acknowledgements
Our code is heavily based on [DeepSpeed-Chat](https://github.com/microsoft/DeepSpeedExamples/tree/master/applications/DeepSpeed-Chat). Please refer to DeepSpeed-Chat for detailed usage instructions.
## Bibtex
If you find this code helpful, please cite our paper:
```
@article{li2023remax,
title = {ReMax: A Simple, Effective, and Efficient Method for Aligning Large Language Models},
author = {Li, Ziniu and Xu, Tian and Zhang, Yushun and Yu, Yang and Sun, Ruoyu and Luo, Zhi-Quan},
journal = {arXiv preprint arXiv:2310.10505},
year = {2023},
}
```