<h1 align="center">N-gram Decoding</h1>

<p align="center">
    <img src="https://img.shields.io/badge/python-3.9.10-orange"
         alt="python version">
    <img src="https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/astral-sh/uv/main/assets/badge/v0.json"
         alt="uv">
    <img src="https://img.shields.io/endpoint?url=https://raw.githubusercontent.com/charliermarsh/ruff/main/assets/badge/v1.json"
         alt="ruff">
</p>

(Re)-implementation of "Prompt Lookup Decoding" by Apoorv Saxena, with extended ideas from LLMA Decoding. Repository: https://github.com/wtlow003/ngram-decoding (MIT license). Accompanying blog post: https://www.jensenlwt.com/blog/understanding-speculative-decoding-for-llm-inference

## About

This repository implements n-gram decoding (also known as *prompt lookup decoding*) for faster LLM inference.

The goal of this exploration is to understand how n-grams can be used for lossless acceleration of LLM inference, as proposed in:

1. [Prompt Lookup Decoding](https://github.com/apoorvumang/prompt-lookup-decoding?tab=readme-ov-file)
2. [LLMA Decoding](https://github.com/microsoft/LMOps/tree/main/llma)

Combining the core ideas of both methods, I explored the following algorithm:

1. Take the last `ngrams_size` tokens of the current sequence (prompt plus generated tokens) and search for earlier occurrences of that n-gram in the sequence. The (up to) `K` tokens that follow a match form a set of candidate tokens.
2. If multiple candidate sets are found, select the one with the most tokens. In case of a tie, select one at random.
3. If no candidate tokens are found, fall back to a single step of greedy decoding.
4. Otherwise, verify the candidates in a single forward pass: accept the longest prefix of the candidates that matches the model's greedy predictions, plus the model's own prediction for the next token.
5. Repeat the above steps until `n` new tokens have been generated or the `EOS` token (e.g., `<|eot_id|>`) is produced.

> [!NOTE]
> Each n-gram decoding step therefore generates between `1` and `K+1` tokens.

## Getting Started

This project uses uv for dependency management. To install uv, run one of the following commands:

```bash
# On macOS and Linux.
curl -LsSf https://astral.sh/uv/install.sh | sh

# On Windows.
powershell -c "irm https://astral.sh/uv/install.ps1 | iex"

# With pip.
pip install uv

# With pipx.
pipx install uv

# With Homebrew.
brew install uv

# With Pacman.
pacman -S uv
```

Then create a virtual environment and install the remaining dependencies with uv:

```bash
# create a virtual environment (in .venv)
uv venv

# activate it (macOS and Linux)
source .venv/bin/activate

# install dependencies from requirements.txt
uv pip install -r requirements.txt
```

## Usage

> [!NOTE]
>
> Currently, the script only supports the `Meta-Llama-3.1-8B-Instruct` model.

```bash
# check CLI options
python main.py --help

usage: main.py [-h] [--model MODEL] --decoding-method {greedy,ngram}

optional arguments:
  -h, --help            show this help message and exit
  --model MODEL
  --decoding-method {greedy,ngram}
```

To run the LLM inference comparison:

```bash
# n-gram decoding
python main.py --model meta-llama/Meta-Llama-3.1-8B-Instruct \
    --decoding-method ngram

# greedy decoding
python main.py --model meta-llama/Meta-Llama-3.1-8B-Instruct \
    --decoding-method greedy
```

## Results

The following results were obtained on an `A100` GPU with `40GB` of memory, using these settings:

1. `ngrams_size` = 3 (length of the n-gram being matched)
2. `K` = 10 (maximum number of candidate tokens per step)
3. `n` = 400 (maximum number of new tokens)

https://github.com/user-attachments/assets/5b103571-a9ea-4e46-ad52-c3f91589c83e

The example prompt used:

````
<|start_header_id|>user<|end_header_id|>
Code:
```python
    def generate_candidate_tokens(
        input_ids: torch.Tensor, n_grams: torch.Tensor, ngrams_size: int, K: int
    ):
        # unfold the tensor into windows of `pattern_len + following_elements_count`
        window = input_ids.unfold(dimension=1, size=ngrams_size, step=1)
        # compare each window with the pattern (only the parts corresponding to the pattern)
        matching_window_indices = (window == n_grams).all(dim=2)
        # extract the indices where there are matches
        matching_indices = matching_window_indices.nonzero(as_tuple=True)[1]

        # find candidates with the longest length
        # based on: https://arxiv.org/pdf/2304.04487
        # we choose the candidate with the longest length at random if there are multiple candidates
        candidates = []
        max_length = K
        for idx in matching_indices:
            start_idx = idx + ngrams_size
            end_idx = start_idx + K
            candidate = input_ids[0, start_idx : min(end_idx, input_ids.size(1))]
            length = len(candidate)

            if length == max_length:
                candidates.append(candidate)
            else:
                # we do not consider prefix with no candidates
                if length > max_length:
                    max_length = length
                    candidates = [candidate]

        if candidates:
            chosen_candidate = candidates[np.random.randint(len(candidates))]
        else:
            chosen_candidate = torch.tensor([], dtype=torch.long, device=input_ids.device)

        return chosen_candidate.unsqueeze(dim=0)
    ```

 Question: Can you the variable name 'candidates' to 'candidates_tokens'?

 Modified code:
<|start_header_id|>assistant<|end_header_id|>
````

The following timings were observed:

| Decoding Method | Time Taken (s) | Tokens/s | Speedup |
| --------------- | -------------- | -------- | ------- |
| Greedy Decoding | 26.4           | 14.0     | 1x      |
| N-gram Decoding | 12.8           | 28.9     | ~2x     |

In this simple demonstration, n-gram decoding was roughly 2x faster than greedy decoding. This is comparable to the original [Prompt Lookup Decoding](https://github.com/apoorvumang/prompt-lookup-decoding?tab=readme-ov-file) implementation and to the figures reported for [LLMA Decoding](https://github.com/microsoft/LMOps/tree/main/llma), both of which report approximately a 2-3x speedup over greedy decoding.

## References

```
@misc{saxena2023prompt,
    title = {Prompt Lookup Decoding},
    author = {Apoorv Saxena},
    year = {2023},
    month = {November},
    url = {https://github.com/apoorvumang/prompt-lookup-decoding/}
}

@misc{yang2023inferencereferencelosslessacceleration,
    title = {Inference with Reference: Lossless Acceleration of Large Language Models},
    author = {Nan Yang and Tao Ge and Liang Wang and Binxing Jiao and Daxin Jiang and Linjun Yang and Rangan Majumder and Furu Wei},
    year = {2023},
    eprint = {2304.04487},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL},
    url = {https://arxiv.org/abs/2304.04487}
}
```

## Acknowledgements

The n-gram decoding implementation builds upon the following repositories:

1. https://github.com/apoorvumang/prompt-lookup-decoding?tab=readme-ov-file
2. https://github.com/microsoft/LMOps/tree/main/llma
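To make the n-gram decoding algorithm (candidate lookup, longest-candidate selection with random tie-breaking, greedy fallback, and verification) concrete, here is a minimal, framework-free Python sketch. It is an illustration under stated assumptions, not this repository's implementation (which works on PyTorch tensors with Llama 3.1): `greedy_fn` is a hypothetical stand-in for one greedy step of the model, tokens are plain integers, and the hypothetical parameters `ngram_size`, `k`, and `max_new_tokens` play the roles of `ngrams_size`, `K`, and `n`. Where a real model scores all candidate positions in one forward pass, the sketch simply calls `greedy_fn` once per prefix.

```python
import random
from typing import Callable, List, Optional, Sequence

# Hypothetical stand-in for the LLM: returns the greedy (argmax) next token for a prefix.
GreedyFn = Callable[[List[int]], int]


def find_candidates(tokens: Sequence[int], ngram_size: int, k: int,
                    rng: random.Random) -> List[int]:
    """Return up to k tokens that followed an earlier match of the trailing n-gram."""
    if len(tokens) <= ngram_size:
        return []
    pattern = list(tokens[-ngram_size:])
    best: List[List[int]] = []
    # Scan every window except the trailing one, which is the pattern itself.
    for start in range(len(tokens) - ngram_size):
        if list(tokens[start:start + ngram_size]) != pattern:
            continue
        follow = list(tokens[start + ngram_size:start + ngram_size + k])
        if not best or len(follow) > len(best[0]):
            best = [follow]        # strictly longer: start a new best set
        elif len(follow) == len(best[0]):
            best.append(follow)    # tie: keep it for the random choice
    return rng.choice(best) if best else []


def verify(candidates: List[int], preds: List[int]) -> List[int]:
    """preds[i] is the greedy token after prefix + candidates[:i] (len(candidates) + 1 items).

    Accept the longest matching prefix of candidates, then the model's own next token,
    so each step yields between 1 and len(candidates) + 1 tokens.
    """
    accepted: List[int] = []
    for cand, pred in zip(candidates, preds):
        if cand != pred:
            break
        accepted.append(cand)
    accepted.append(preds[len(accepted)])
    return accepted


def ngram_decode(prompt: Sequence[int], greedy_fn: GreedyFn, ngram_size: int = 3,
                 k: int = 10, max_new_tokens: int = 400,
                 eos_id: Optional[int] = None, seed: int = 0) -> List[int]:
    rng = random.Random(seed)
    tokens = list(prompt)
    out: List[int] = []
    while len(out) < max_new_tokens:
        cands = find_candidates(tokens, ngram_size, k, rng)  # [] -> plain greedy step
        # A real LLM scores all len(cands) + 1 positions in one forward pass;
        # this sketch calls greedy_fn once per prefix instead.
        preds = [greedy_fn(tokens + cands[:i]) for i in range(len(cands) + 1)]
        new = verify(cands, preds)[: max_new_tokens - len(out)]
        if eos_id is not None and eos_id in new:
            new = new[: new.index(eos_id) + 1]  # drop anything after EOS
        tokens.extend(new)
        out.extend(new)
        if eos_id is not None and out[-1] == eos_id:
            break
    return out


def greedy_decode(prompt: Sequence[int], greedy_fn: GreedyFn, max_new_tokens: int = 400,
                  eos_id: Optional[int] = None) -> List[int]:
    """Baseline: one token per model call."""
    tokens = list(prompt)
    out: List[int] = []
    while len(out) < max_new_tokens:
        nxt = greedy_fn(tokens)
        tokens.append(nxt)
        out.append(nxt)
        if nxt == eos_id:
            break
    return out


def counting(seq: List[int]) -> int:
    """Toy 'model' that always predicts the previous token + 1 (mod 20)."""
    return (seq[-1] + 1) % 20


prompt = list(range(20)) + [0, 1, 2]
print(ngram_decode(prompt, counting, max_new_tokens=30)
      == greedy_decode(prompt, counting, max_new_tokens=30))  # True
```

Because a candidate is only accepted when it equals the model's own greedy choice, `ngram_decode` returns exactly what `greedy_decode` returns; with a real model, the gain comes from committing up to `K+1` tokens per forward pass. With the toy `counting` model and the prompt above, the first step copies the ten tokens `3` to `12` from the prompt, accepts all of them, and adds the model's next token `13`.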