https://github.com/murrellgroup/logitsamplers.jl
- Host: GitHub
- URL: https://github.com/murrellgroup/logitsamplers.jl
- Owner: MurrellGroup
- License: mit
- Created: 2024-11-25T09:43:22.000Z (over 1 year ago)
- Default Branch: main
- Last Pushed: 2024-12-01T14:39:24.000Z (over 1 year ago)
- Last Synced: 2025-02-20T20:06:30.262Z (over 1 year ago)
- Language: Julia
- Homepage:
- Size: 228 KB
- Stars: 1
- Watchers: 2
- Forks: 0
- Open Issues: 1
Metadata Files:
- Readme: README.md
- License: LICENSE
README
# LogitSamplers
[Stable docs](https://MurrellGroup.github.io/LogitSamplers.jl/stable/)
[Dev docs](https://MurrellGroup.github.io/LogitSamplers.jl/dev/)
[Build Status](https://github.com/MurrellGroup/LogitSamplers.jl/actions/workflows/CI.yml?query=branch%3Amain)
[Coverage](https://codecov.io/gh/MurrellGroup/LogitSamplers.jl)
A Julia package for GPU-friendly sampling from logit distributions with various transformation methods commonly used in language models.
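To make "sampling from a logit distribution" concrete, here is a minimal plain-Julia sketch that does not use LogitSamplers at all. It assumes the Gumbel-max trick as the sampling method. This is one common GPU-friendly approach because it needs only elementwise operations and a single `argmax`, with no cumulative sum. The helper name `gumbel_argmax` is hypothetical, and this is not necessarily how the package's `logitsample` is implemented.

```julia
using Random

# Hypothetical helper, not part of LogitSamplers: draws index i with
# probability softmax(logits)[i] using the Gumbel-max trick.
function gumbel_argmax(rng::AbstractRNG, logits::AbstractVector{<:Real})
    T = float(eltype(logits))
    u = rand(rng, T, length(logits))   # Uniform[0, 1) draws, one per logit
    g = @. -log(-log(u))               # standard Gumbel(0, 1) noise
    return argmax(logits .+ g)         # index of the largest perturbed logit
end

# Convenience method that uses Julia's default RNG
gumbel_argmax(logits::AbstractVector{<:Real}) = gumbel_argmax(Random.default_rng(), logits)
```

Entries masked to `-Inf` stay at `-Inf` after adding finite noise, so they are never selected. Over many draws, the empirical index frequencies approach `softmax(logits)`.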
## Usage
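The package provides a set of logit transforms that modify a distribution in the log domain: each transform takes a vector of logits and returns a transformed vector of logits. Transforms are callable, compose with `∘`, and can be passed to `logitsample` to draw an index from the resulting distribution.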
```julia
using LogitSamplers

# Create a temperature transform
temperature = Temperature(1.5)

# Create a top-p transform
top_p = Top_p(0.5)

# Compose a function that first applies temperature, then top-p
transform = top_p ∘ temperature

# Create a token index sampler function from the transform
sampler = logitsample ∘ transform

# or equivalently:
sampler = logits -> logitsample(top_p(temperature(logits)))

logits = randn(100)

# Get token probabilities from the transformed logits
# (`softmax` must be in scope; if LogitSamplers does not provide it, it is available from NNlib.jl)
probs = softmax(transform(logits))

# Sample a token index using the sampler
index = sampler(logits)
```
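The README does not spell out what `Temperature` and `Top_p` compute. The following plain-Julia sketch shows the usual definitions, on the assumption that the package follows them. Temperature divides the logits by `T`. Top-p (nucleus) filtering keeps the smallest set of most probable tokens whose cumulative probability reaches `p` and masks the rest with `-Inf`. The names `stable_softmax`, `temperature_sketch` and `top_p_sketch` are hypothetical. The sketch also relies on `sortperm` and `cumsum` on the CPU, so it makes no claim about how the package's GPU-friendly transforms work, how they break ties, or how they treat the boundary at exactly `p`.

```julia
# Numerically stable softmax: subtract the maximum before exponentiating.
function stable_softmax(x::AbstractVector{<:Real})
    e = exp.(x .- maximum(x))
    return e ./ sum(e)
end

# Hypothetical stand-in for Temperature(T): divides every logit by T.
# T > 1 flattens the distribution; T < 1 sharpens it.
temperature_sketch(T::Real) = logits -> logits ./ T

# Hypothetical stand-in for Top_p(p): keeps the smallest set of most probable
# tokens whose total probability reaches p and sets all other logits to -Inf.
function top_p_sketch(p::Real)
    return function (logits::AbstractVector{<:Real})
        probs = stable_softmax(logits)
        order = sortperm(probs; rev = true)   # indices, most probable first
        cum = cumsum(probs[order])
        # Number of tokens needed for the cumulative mass to reach p; fall back
        # to all tokens if rounding leaves the total just below p.
        k = something(findfirst(>=(p), cum), length(cum))
        out = fill(float(eltype(logits))(-Inf), length(logits))
        keep = order[1:k]
        out[keep] .= logits[keep]
        return out
    end
end

# Compose as in the README example: temperature first, then top-p.
transform = top_p_sketch(0.5) ∘ temperature_sketch(1.5)
probs = stable_softmax(transform(randn(100)))
```

Each transform maps logits to logits, so these sketches compose with `∘` in the same way as the package's transforms in the example above. Representing removed tokens as `-Inf` (log 0) keeps everything in the log domain, and those tokens receive probability zero after `softmax`.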