https://github.com/cyberagentailab/regularized-bon
Code of "Regularized Best-of-N Sampling with Minimum Bayes Risk Objective for Language Model Alignment" (2025).
- Host: GitHub
- URL: https://github.com/cyberagentailab/regularized-bon
- Owner: CyberAgentAILab
- License: mit
- Created: 2024-03-31T09:19:59.000Z (about 2 years ago)
- Default Branch: master
- Last Pushed: 2025-04-04T04:05:14.000Z (about 1 year ago)
- Last Synced: 2025-09-10T07:42:49.454Z (9 months ago)
- Language: Python
- Homepage: https://arxiv.org/abs/2404.01054
- Size: 55.7 KB
- Stars: 14
- Watchers: 1
- Forks: 1
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
## Regularized Best-of-N
Implementation of [Regularized Best-of-N (RBoN)](https://arxiv.org/abs/2404.01054).
The code is tested on Ubuntu 20.04 using Python 3.8 and CUDA 11.0 (Docker image nvidia/cuda:11.0.3-cudnn8-devel-ubuntu20.04).
```
git clone git@github.com:CyberAgentAILab/regularized-bon
cd regularized-bon
pip install -r requirements.txt
```
## Usage
Running RBoN takes several steps:
1. Generate a set of responses using sample.sh. The same set of samples is used for all algorithms to ensure a fair comparison.
2. Compute the Wasserstein distance and KL divergence using compute_wd.sh and compute_logprob.sh.
3. Compute the reward of the responses using compute_reward.sh.
4. Run mbr/compute_rbon.py to compute MBR-BoN (RBoN-WD) and RBoN-KL.
The resulting CSV file is written to the results/ directory.
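The steps above can be chained from a small driver script. The following is a minimal sketch, not part of the repository: the function names `build_pipeline` and `run_pipeline` and the sample count are hypothetical, and it assumes the scripts are invoked from the repository root with the flags shown in this README. It defaults to a dry run that only prints the commands.

```python
import shlex
import subprocess

# Reward models taken from the commands in this README.
REWARD_MODELS = (
    "stanfordnlp/SteamSHP-flan-t5-large",
    "OpenAssistant/reward-model-deberta-v3-large-v2",
)


def build_pipeline(dataset="alpaca", n_samples=8, reward_models=REWARD_MODELS):
    """Return the README commands, in order, as argument lists."""
    s = str(n_samples)
    commands = [
        ["./experiments/sample.sh", "-d", dataset, "-s", s],
        ["./experiments/compute_wd.sh", "-d", dataset, "-s", s],
        ["./experiments/compute_logprob.sh", "-d", dataset, "-s", s],
    ]
    for model in reward_models:
        commands.append(
            ["./experiments/compute_reward.sh", "-d", dataset, "-s", s, "-i", model]
        )
    commands.append(
        ["python3", "mbr/compute_rbon.py", "--dataset", dataset, "--ncandidates", s]
    )
    return commands


def run_pipeline(commands, dry_run=True):
    """Print each command; execute it only when dry_run is False."""
    for cmd in commands:
        print(shlex.join(cmd))
        if not dry_run:
            # check=True stops the pipeline at the first failing step.
            subprocess.run(cmd, check=True)


if __name__ == "__main__":
    run_pipeline(build_pipeline(n_samples=8), dry_run=True)
```

Passing `dry_run=False` executes each step and stops at the first failure, which matters because every later step reads the output of the earlier ones.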
### Sampling candidates
By default, sampling uses [openai-community/gpt2](https://huggingface.co/openai-community/gpt2). Add `-m [MODEL NAME IN HUGGINGFACE HUB]` to change the language model.
```
./experiments/sample.sh -d alpaca -s [NUMBER OF SAMPLES]
```
For backward compatibility within the codebase, sample.py must select a prompt file even for tasks such as AlpacaFarm that have no shared prompt.
To handle this, the repository provides [dummy.txt](https://github.com/CyberAgentAILab/regularized-bon/blob/master/prompts/dummy.txt), a blank file that is selected to indicate that the task has no shared prompt.
### Computing Wasserstein distance
```
./experiments/compute_wd.sh -d alpaca -s [NUMBER OF SAMPLES]
```
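To illustrate the quantity involved: when one distribution is a point mass on a single candidate and the other is the empirical distribution over the N samples, the 1-Wasserstein distance reduces to the mean cost from that candidate to every sample, because all of the mass has to be moved from that one point. The sketch below assumes cosine distance between embedding vectors as the cost and uses toy embeddings. The function names and the choice of cost are illustrative assumptions and may differ from what compute_wd.sh actually computes.

```python
import math


def cosine_distance(u, v):
    """1 - cosine similarity of two non-zero vectors (assumed cost function)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return 1.0 - dot / (norm_u * norm_v)


def point_mass_wd(embeddings):
    """For each candidate, the mean cost to all samples (itself included).

    This equals the 1-Wasserstein distance between a point mass on the
    candidate and the empirical distribution over the samples.
    """
    n = len(embeddings)
    return [
        sum(cosine_distance(e, other) for other in embeddings) / n
        for e in embeddings
    ]


if __name__ == "__main__":
    # Toy embeddings: two identical candidates and one orthogonal outlier.
    toy = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
    print(point_mass_wd(toy))
```

In this toy case the outlier receives the largest distance, which is the property a proximity regularizer relies on.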
### Computing log probability
```
./experiments/compute_logprob.sh -d alpaca -s [NUMBER OF SAMPLES]
```
### Computing the reward of the samples
```
./experiments/compute_reward.sh -d alpaca -s [NUMBER OF SAMPLES] -i stanfordnlp/SteamSHP-flan-t5-large
./experiments/compute_reward.sh -d alpaca -s [NUMBER OF SAMPLES] -i OpenAssistant/reward-model-deberta-v3-large-v2
```
### Computing MBR-BoN and RBoN-KL
```
python3 mbr/compute_rbon.py --dataset alpaca --ncandidates [NUMBER OF SAMPLES]
```
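As a rough illustration of the selection step, the sketch below assumes both objectives have the form "reward minus β times a proximity penalty": the penalty for RBoN-WD is the per-candidate Wasserstein term, and the penalty for RBoN-KL is the KL divergence of a point mass on the candidate from the reference policy, which equals the negative log probability of the candidate. The function names, the toy numbers, and this exact form are assumptions for illustration. mbr/compute_rbon.py may differ in details such as normalization.

```python
def _argmax(scores):
    """Index of the highest score (first index wins on ties)."""
    return max(range(len(scores)), key=scores.__getitem__)


def select_rbon_wd(rewards, wd, beta):
    """Pick the candidate maximizing reward - beta * Wasserstein term."""
    return _argmax([r - beta * d for r, d in zip(rewards, wd)])


def select_rbon_kl(rewards, logprobs, beta):
    """Pick the candidate maximizing reward + beta * log-probability.

    For a point mass on y, KL(point mass || pi_ref) = -log pi_ref(y),
    so subtracting beta * KL is the same as adding beta * log-prob.
    """
    return _argmax([r + beta * lp for r, lp in zip(rewards, logprobs)])


if __name__ == "__main__":
    # Toy values: candidate 1 has the highest reward but is an outlier.
    rewards = [1.0, 2.0, 1.5]
    wd = [0.1, 0.9, 0.2]
    logprobs = [-1.0, -10.0, -3.0]
    print(select_rbon_wd(rewards, wd, beta=0.0))  # plain Best-of-N
    print(select_rbon_wd(rewards, wd, beta=2.0))
    print(select_rbon_kl(rewards, logprobs, beta=0.5))
```

With β = 0 both rules reduce to plain Best-of-N. Larger β shifts the choice toward candidates that are closer to the bulk of the samples or more probable under the reference model.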
## Reference
Jinnai, Y., Morimura, T., Ariu, K., and Abe, K. Regularized Best-of-N Sampling with Minimum Bayes Risk Objective for Language Model Alignment. 2025.
BibTeX:
```
@misc{jinnai2025regularizedbestofnsamplingminimum,
title={Regularized Best-of-N Sampling with Minimum Bayes Risk Objective for Language Model Alignment},
author={Yuu Jinnai and Tetsuro Morimura and Kaito Ariu and Kenshi Abe},
year={2025},
eprint={2404.01054},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2404.01054},
}
```
## Contact
For any questions, feel free to raise an issue or contact me at jinnai_yu@cyberagent.co.jp.