# LLM Sampling Library
## Overview
The LLM Sampling Library is a Python package for text generation with large language models (LLMs) using a variety of sampling methods. It provides a simple command-line interface (CLI) for generating text from input prompts, using models from the Hugging Face Transformers library.
## Features
- **Multiple Sampling Methods**: Implements various sampling techniques including:
- Unconstrained sampling
- Top-k sampling
- Top-p (nucleus) sampling
  - [Min-p sampling](https://arxiv.org/abs/2407.01082) (see the sketch after this list)
- [Typical sampling](https://arxiv.org/abs/2202.00666)
- [Epsilon sampling](https://arxiv.org/abs/2210.15191)
- [Eta sampling](https://arxiv.org/abs/2210.15191)
- Beam search
- [Chain-of-Thought (CoT) decoding](https://arxiv.org/abs/2402.10200)
- Constrained JSON decoding
- [Speculative sampling](https://arxiv.org/abs/2302.01318)
- [Medusa decoding](https://arxiv.org/abs/2401.10774)
- **Chat Template Support**: Optionally apply chat templates for Instruct models
- **KV Caching**: Caches attention keys/values so the prompt and previously generated tokens are not recomputed at each decoding step
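As an example of one of these methods, min-p sampling keeps only tokens whose probability is at least `min_p` times the probability of the most likely token, so the cutoff adapts to the model's confidence. A minimal PyTorch sketch of the idea (illustrative only, not the library's implementation):

```python
import torch

def min_p_sample(logits: torch.Tensor, min_p: float = 0.1) -> torch.Tensor:
    """Sample one token id from (vocab_size,) logits using min-p filtering."""
    probs = torch.softmax(logits, dim=-1)
    # The cutoff scales with the probability of the most likely token,
    # so it is strict when the model is confident and loose when it is not.
    threshold = min_p * probs.max()
    probs = torch.where(probs >= threshold, probs, torch.zeros_like(probs))
    probs = probs / probs.sum()  # renormalize the surviving mass
    return torch.multinomial(probs, num_samples=1)
```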
In all implementations, temperature scaling is applied to the logits before any sampling method, following the [GPT-2 implementation](https://github.com/openai/gpt-2/blob/9b63575ef42771a015060c964af2c3da4cf7c8ab/src/sample.py#L61C5-L72C14) and [HuggingFace's implementation](https://github.com/huggingface/transformers/blob/acc394c4f5e1283c19783581790b3dc3105a3697/src/transformers/generation/utils.py#L825C9-L847C23).
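The ordering matters: dividing the logits by the temperature reshapes the full distribution first, and only then is the truncation rule applied. A minimal sketch using top-p (nucleus) filtering as the truncation step (again illustrative, not the repository's exact code):

```python
import torch

def sample_top_p(logits: torch.Tensor, temperature: float = 1.0,
                 top_p: float = 0.9) -> torch.Tensor:
    """Sample one token id from (vocab_size,) logits with nucleus sampling."""
    # 1. Temperature scaling happens before any truncation.
    probs = torch.softmax(logits / temperature, dim=-1)
    # 2. Keep the smallest set of tokens whose cumulative probability
    #    reaches top_p; the mask keeps tokens whose *preceding* mass is
    #    still below top_p, so the top token always survives.
    sorted_probs, sorted_idx = torch.sort(probs, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    keep = (cumulative - sorted_probs) < top_p
    sorted_probs = sorted_probs * keep
    sorted_probs = sorted_probs / sorted_probs.sum()
    # 3. Sample from the truncated, renormalized distribution.
    choice = torch.multinomial(sorted_probs, num_samples=1)
    return sorted_idx[choice]
```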
## Installation
```bash
pip install -r requirements.txt
```
## Usage
Basic usage:
```bash
python generate.py --model <model> --prompt "<prompt>" --apply-chat-template --temperature <temperature> --method <method> --max_new_tokens <max_new_tokens> --hf-token <hf_token> --dtype <dtype>
```
### Parameters
- `--model`: The path or name of the Hugging Face model to use
- `--prompt`: The input sequence for the model
- `--prompt_file`: Alternative to `--prompt`; load the prompt from a file
- `--apply-chat-template`: Apply the model's chat template to the prompt (for Instruct models)
- `--temperature`: Sampling temperature (default: 1.0)
- `--method`: Sampling method to use (see list below)
- `--max_new_tokens`: Maximum number of new tokens to generate (default: 500)
- `--hf-token`: Your Hugging Face token for model access
- `--dtype`: Data type for model (bfloat16, float16, float32)
- `--seed`: Random seed for reproducibility
Method-specific parameters:
- `--top_k`: K value for top-k sampling
- `--top_p`: P value for nucleus sampling
- `--min_p`: Threshold for min-p sampling
- `--epsilon`: Epsilon value for epsilon/eta sampling
- `--beam_width`: Beam width for beam search
- `--typical_p_mass`: Mass parameter for typical sampling
- `--json_schema`: Schema file path for constrained JSON sampling (example schema after this list)
- `--draft-model`: Path to draft model for speculative sampling
- `--medusa-model-heads`: Path to Medusa model heads
- `--lookahead`: Lookahead parameter for speculative sampling
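For constrained JSON decoding, the file passed via `--json_schema` describes the structure the output must follow (the implementation is based on Jsonformer, which consumes JSON Schema-style type descriptions). A hypothetical schema, shown here as its Python-dict equivalent; the field names are purely illustrative:

```python
# Hypothetical contents of a schema file for --json_schema.
# Field names are made up for illustration; the types follow the
# JSON Schema conventions that Jsonformer-style decoders understand.
schema = {
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "age": {"type": "number"},
        "is_student": {"type": "boolean"},
        "courses": {"type": "array", "items": {"type": "string"}},
    },
}
```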
## Examples
Top-k sampling:
```bash
python generate.py --model meta-llama/Llama-3.1-8B-Instruct --prompt "Tell me a story" --method top_k --top_k 50 --temperature 0.7
```
Nucleus sampling:
```bash
python generate.py --model meta-llama/Llama-3.1-8B-Instruct --prompt "Write a poem" --method top_p --top_p 0.9 --temperature 0.8
```
Speculative sampling:
```bash
python generate.py --model meta-llama/Llama-3.1-8B-Instruct --prompt "Explain quantum physics" --method speculative --draft-model meta-llama/Llama-3.2-1B-Instruct --lookahead 4
```
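Here the draft model proposes `--lookahead` tokens per step, the target model scores all of them in one forward pass, and each draft token is accepted with probability min(1, p_target/p_draft). A simplified sketch of that accept/reject rule (assuming the per-position distributions have already been computed; the full algorithm also samples a bonus token when every draft token is accepted, omitted here):

```python
import torch

def verify_draft(draft_tokens: torch.Tensor, p_target: torch.Tensor,
                 p_draft: torch.Tensor) -> list:
    """One verification round of speculative sampling (Chen et al., 2023).

    draft_tokens: (k,) token ids proposed by the draft model
    p_target, p_draft: (k, vocab) next-token distributions from the
        target and draft models at each of the k draft positions
    Returns the accepted tokens, plus one corrected token on rejection.
    """
    accepted = []
    for i, tok in enumerate(draft_tokens):
        p, q = p_target[i, tok], p_draft[i, tok]
        if torch.rand(()) < torch.clamp(p / q, max=1.0):
            accepted.append(tok)  # draft token accepted
        else:
            # On rejection, resample from the residual distribution
            # max(0, p_target - p_draft), renormalized; then stop.
            residual = torch.clamp(p_target[i] - p_draft[i], min=0.0)
            residual = residual / residual.sum()
            accepted.append(torch.multinomial(residual, num_samples=1).squeeze())
            break
    return accepted
```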
Medusa decoding:
```bash
python generate.py --model meta-llama/Llama-3.1-8B-Instruct --prompt "Write code for merge sort" --method medusa --medusa-model-heads <path-to-medusa-heads>
```
## Acknowledgements
- [Medusa](https://github.com/FasterDecoding/Medusa) for the Medusa decoding implementation
- [vLLM](https://github.com/vllm-project/vllm) used as the reference for the Medusa heads architecture
- [Jsonformer](https://github.com/1rgs/jsonformer) for the constrained JSON decoding implementation
- [HuggingFace Transformers](https://github.com/huggingface/transformers) for correctness evaluation