Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/wtlow003/ngram-decoding
(Re)-implementation of "Prompt Lookup Decoding" by Apoorv Saxena, with extended ideas from LLMA Decoding.
- Host: GitHub
- URL: https://github.com/wtlow003/ngram-decoding
- Owner: wtlow003
- License: mit
- Created: 2024-08-19T12:57:20.000Z (5 months ago)
- Default Branch: master
- Last Pushed: 2024-08-20T13:47:11.000Z (5 months ago)
- Last Synced: 2024-11-15T14:17:45.788Z (2 months ago)
- Topics: llm-inference, n-gram, ngram-decoding, prompt-lookup-decoding, speculative-decoding
- Language: Jupyter Notebook
- Homepage: https://www.jensenlwt.com/blog/understanding-speculative-decoding-for-llm-inference
- Size: 552 KB
- Stars: 0
- Watchers: 1
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
README
# N-gram Decoding
## About
This repository contains an implementation of n-gram decoding (aka *prompt lookup decoding*) for faster LLM inference.
This exploration aims to understand the use of n-grams for lossless acceleration of LLM inference, as proposed in:
1. [Prompt Lookup Decoding](https://github.com/apoorvumang/prompt-lookup-decoding?tab=readme-ov-file)
2. [LLMA Decoding](https://github.com/microsoft/LMOps/tree/main/llma)

Combining the core ideas from both methods, I explored the following algorithm built upon the aforementioned works:
1. Match the n-grams in the prompt with the tokens in the input sequence, and obtain `K` candidate tokens.
2. If multiple candidates are found, select the set with the most candidate tokens. In case of a tie, a random selection is made.
3. If no candidate tokens are identified, default to single-step greedy decoding.
4. Repeat the above steps until either the maximum number of new tokens `n` is reached or the `EOS` token (e.g., `<|eot_id|>`) is generated.

> [!NOTE]
> The number of tokens generated per step in n-gram decoding ranges from `1` to `K+1`.
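
To make the loop concrete, here is a minimal, framework-agnostic sketch of the algorithm above. It is illustrative only: `next_tokens_fn` is a stand-in for a single model verification pass (one forward over the drafted tokens that returns the greedy prediction at each position) and is not part of this repository.

```python
import random

def find_candidate(tokens: list[int], ngram_size: int, K: int) -> list[int]:
    """Look up the trailing `ngram_size` tokens earlier in the sequence and
    return up to `K` tokens that followed the longest such match."""
    ngram = tokens[-ngram_size:]
    matches: list[list[int]] = []
    for i in range(len(tokens) - ngram_size):
        if tokens[i : i + ngram_size] == ngram:
            cand = tokens[i + ngram_size : i + ngram_size + K]
            if not matches or len(cand) > len(matches[0]):
                matches = [cand]
            elif len(cand) == len(matches[0]):
                matches.append(cand)
    # tie between equally long candidate sets -> pick one at random (step 2)
    return random.choice(matches) if matches else []

def ngram_decode(next_tokens_fn, prompt: list[int], ngram_size=3, K=10,
                 max_new_tokens=400, eos_id=None) -> list[int]:
    """Draft tokens by copying from the prompt, verify them in one model call,
    and accept the longest prefix the model agrees with (steps 1-4)."""
    tokens = list(prompt)
    generated = 0
    while generated < max_new_tokens:
        draft = find_candidate(tokens, ngram_size, K)        # step 1 (may be empty)
        # preds[i] = greedy token the model predicts after tokens + draft[:i]
        preds = next_tokens_fn(tokens + draft, len(draft))   # len(draft) + 1 predictions
        accepted: list[int] = []
        for i, d in enumerate(draft):
            if d == preds[i]:
                accepted.append(d)                           # draft token confirmed
            else:
                break
        accepted.append(preds[len(accepted)])                # model's own next token
        tokens.extend(accepted)                              # 1 to K+1 tokens per step
        generated += len(accepted)
        if eos_id is not None and eos_id in accepted:        # step 4: stop on EOS
            break
    return tokens
```

Under these assumptions, every iteration costs one model forward pass regardless of how many drafted tokens are accepted, which is where the speedup comes from.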
## Getting Started
This project uses uv for dependency management. To install uv, run one of the following commands:
```bash
# On macOS and Linux.
curl -LsSf https://astral.sh/uv/install.sh | sh

# On Windows.
powershell -c "irm https://astral.sh/uv/install.ps1 | iex"

# With pip.
pip install uv

# With pipx.
pipx install uv

# With Homebrew.
brew install uv

# With Pacman.
pacman -S uv
```

Thereafter, install the rest of the dependencies using uv:
```bash
# create a virtual env
uv venv

# install dependencies
uv pip install -r requirements.txt # Install from a requirements.txt file.
```

## Usage
> [!NOTE]
>
> Currently, the script only supports the `Meta-Llama-3.1-8B-Instruct` model.

```bash
# check cli options
python main.py --help

usage: main.py [-h] [--model MODEL] --decoding-method {greedy,ngram}

optional arguments:
-h, --help show this help message and exit
--model MODEL
--decoding-method {greedy,ngram}
```

Running the LLM inference comparison script:
```bash
# ngram decoding
python main.py --model meta-llama/Meta-Llama-3.1-8B-Instruct \
    --decoding-method ngram

# greedy decoding
python main.py --model meta-llama/Meta-Llama-3.1-8B-Instruct \
--decoding-method greedy
```

## Results
The following results were obtained on an `A100` GPU with `40GB` of GPU memory, using the following settings:
1. `ngrams_size` = 3
2. `K` = 10
3. `n` = 400

https://github.com/user-attachments/assets/5b103571-a9ea-4e46-ad52-c3f91589c83e
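
For orientation, here is a minimal illustration of how `ngrams_size` and `K` feed the candidate-matching step. It assumes PyTorch and the `generate_candidate_tokens` helper shown in the example prompt below; the toy `input_ids` tensor is invented purely for this demonstration.

```python
import torch

ngrams_size, K = 3, 10  # settings 1 and 2 above; `n` = 400 caps the total new tokens

# toy running sequence: the trailing 3-gram [1, 2, 3] also appears at the start,
# so the 10 tokens that followed it earlier are proposed as the draft
input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 1, 2, 3]])
n_grams = input_ids[0, -ngrams_size:]

draft = generate_candidate_tokens(input_ids, n_grams, ngrams_size, K)
print(draft)  # tensor([[ 4,  5,  6,  7,  8,  9, 10, 11, 12, 13]])
```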
Using the following example prompt:
```
<|start_header_id|>user<|end_header_id|>
Code:
```python
def generate_candidate_tokens(
    input_ids: torch.Tensor, n_grams: torch.Tensor, ngrams_size: int, K: int
):
    # unfold the tensor into windows of `pattern_len + following_elements_count`
    window = input_ids.unfold(dimension=1, size=ngrams_size, step=1)
    # compare each window with the pattern (only the parts corresponding to the pattern)
    matching_window_indices = (window == n_grams).all(dim=2)
    # extract the indices where there are matches
    matching_indices = matching_window_indices.nonzero(as_tuple=True)[1]

    # find candidates with the longest length
    # based on: https://arxiv.org/pdf/2304.04487
    # we choose the candidate with the longest length at random if there are multiple candidates
    candidates = []
    max_length = K
    for idx in matching_indices:
        start_idx = idx + ngrams_size
        end_idx = start_idx + K
        candidate = input_ids[0, start_idx : min(end_idx, input_ids.size(1))]
        length = len(candidate)

        if length == max_length:
            candidates.append(candidate)
        else:
            # we do not consider prefix with no candidates
            if length > max_length:
                max_length = length
                candidates = [candidate]

    if candidates:
        chosen_candidate = candidates[np.random.randint(len(candidates))]
    else:
        chosen_candidate = torch.tensor([], dtype=torch.long, device=input_ids.device)

    return chosen_candidate.unsqueeze(dim=0)
```

Question: Can you change the variable name 'candidates' to 'candidates_tokens'?
Modified code:
<|start_header_id|>assistant<|end_header_id|>
```

The following timings are observed:
| Decoding Method      | Time Taken (s)   | Tokens/sec   | Speedup     |
| -------------------- | ---------------- | ------------ | ----------- |
| Greedy Decoding | 26.4 | 14.0 | 1x |
| Ngrams Decoding      | 12.8             | 28.9         | ~2x         |

In this simple demonstration experiment, we achieved results comparable to those of the original [Prompt Lookup Decoding](https://github.com/apoorvumang/prompt-lookup-decoding?tab=readme-ov-file) implementation and the figures reported in [LLMA Decoding](https://github.com/microsoft/LMOps/tree/main/llma), both of which report approximately a 2-3x improvement in speed over greedy decoding.
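
As a quick sanity check, the speedup column follows directly from the measured timings and throughputs in the table above:

```python
# speedup implied by the table above: wall-clock ratio and throughput ratio agree
greedy_time, ngram_time = 26.4, 12.8   # seconds
greedy_tps, ngram_tps = 14.0, 28.9     # tokens per second

print(f"wall-clock speedup: {greedy_time / ngram_time:.2f}x")   # ~2.06x
print(f"throughput speedup: {ngram_tps / greedy_tps:.2f}x")     # ~2.06x
```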
## References
```
@misc{saxena2023prompt,
title = {Prompt Lookup Decoding},
author = {Apoorv Saxena},
year = {2023},
month = {November},
url = {https://github.com/apoorvumang/prompt-lookup-decoding/}
}

@misc{yang2023inferencereferencelosslessacceleration,
title={Inference with Reference: Lossless Acceleration of Large Language Models},
author={Nan Yang and Tao Ge and Liang Wang and Binxing Jiao and Daxin Jiang and Linjun Yang and Rangan Majumder and Furu Wei},
year={2023},
eprint={2304.04487},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2304.04487},
}
```

## Acknowledgements
The implementation of ngram-decoding is built upon the following repositories:
1. https://github.com/apoorvumang/prompt-lookup-decoding?tab=readme-ov-file
2. https://github.com/microsoft/LMOps/tree/main/llma