https://github.com/linkedin/Liger-Kernel
Efficient Triton Kernels for LLM Training
- Host: GitHub
- URL: https://github.com/linkedin/Liger-Kernel
- Owner: linkedin
- License: bsd-2-clause
- Created: 2024-08-06T17:47:52.000Z (about 1 year ago)
- Default Branch: main
- Last Pushed: 2024-12-13T00:00:32.000Z (10 months ago)
- Last Synced: 2024-12-13T01:17:26.920Z (10 months ago)
- Topics: finetuning, gemma2, llama, llama3, llm-training, llms, mistral, phi3, triton, triton-kernels
- Language: Python
- Homepage: https://arxiv.org/pdf/2410.10989
- Size: 14.1 MB
- Stars: 3,809
- Watchers: 40
- Forks: 228
- Open Issues: 70
Metadata Files:
- Readme: README.md
- Contributing: docs/CONTRIBUTING.md
- License: LICENSE
Awesome Lists containing this project
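- StarryDivineSky - linkedin/Liger-Kernel - In the same spirit as Flash-Attn, but for RMSNorm, RoPE, SwiGLU, and CrossEntropy! Increases multi-GPU training throughput by 20% and reduces memory usage by 60% through kernel fusion, in-place replacement, and chunking. Exact: computation is exact, with no approximations! Both forward and backward passes are implemented with rigorous unit tests and convergence-tested against training runs without Liger Kernel to ensure accuracy. Lightweight: Liger Kernel has minimal dependencies, requiring only Torch and Triton, with no extra libraries needed! Say goodbye to dependency headaches! Multi-GPU supported: compatible with multi-GPU setups (PyTorch FSDP, DeepSpeed, DDP, etc.). Trainer framework integration: Axolotl, LLaMa-Factory, SFTTrainer, Hugging Face Trainer, SWIFT (A01_Text Generation_Text Dialogue / Large Language Dialogue Models and Data)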
- awesome-production-machine-learning - Liger Kernel - Liger Kernel is a collection of Triton kernels designed specifically for LLM training. (Computation and Communication Optimisation)
- awesome-LLM-resources - Liger-Kernel
- Awesome-Local-LLM - _Git_
- awesome-llm-and-aigc - Liger-Kernel: Efficient Triton Kernels for LLM Training. [arxiv.org/pdf/2410.10989](https://arxiv.org/pdf/2410.10989) (Summary)
- awesome-cuda-and-hpc - Liger-Kernel: Efficient Triton Kernels for LLM Training. [arxiv.org/pdf/2410.10989](https://arxiv.org/pdf/2410.10989) (Frameworks)
- awesome-llm-pretraining - [code
- awesome - linkedin/Liger-Kernel - Efficient Triton Kernels for LLM Training (Python)
README
# Liger Kernel: Efficient Triton Kernels for LLM Training
[Installation](#installation) | [Getting Started](#getting-started) | [Examples](#examples) | [High-level APIs](#high-level-apis) | [Low-level APIs](#low-level-apis) | [Cite our work](#cite-this-work)
Latest News 🔥
- [2025/03/06] We release a joint blog post on TorchTune × Liger - [Peak Performance, Minimized Memory: Optimizing torchtune’s performance with torch.compile & Liger Kernel](https://pytorch.org/blog/peak-performance-minimized-memory/)
- [2024/12/11] We release [v0.5.0](https://github.com/linkedin/Liger-Kernel/releases/tag/v0.5.0): 80% more memory efficient post training losses (DPO, ORPO, CPO, etc)!
- [2024/12/5] We release LinkedIn Engineering Blog - [Liger-Kernel: Empowering an open source ecosystem of Triton Kernels for Efficient LLM Training](https://www.linkedin.com/blog/engineering/open-source/liger-kernel-open-source-ecosystem-for-efficient-llm-training)
- [2024/11/6] We release [v0.4.0](https://github.com/linkedin/Liger-Kernel/releases/tag/v0.4.0): Full AMD support, Tech Report, Modal CI, Llama-3.2-Vision!
- [2024/10/21] We have released the tech report of Liger Kernel on Arxiv: https://arxiv.org/pdf/2410.10989
- [2024/9/6] We release v0.2.1 ([X post](https://x.com/liger_kernel/status/1832168197002510649)). 2500+ Stars, 10+ New Contributors, 50+ PRs, 50k Downloads in two weeks!
- [2024/8/31] CUDA MODE talk, [Liger-Kernel: Real-world Triton kernel for LLM Training](https://youtu.be/gWble4FreV4?si=dxPeIchhkJ36Mbns), [Slides](https://github.com/cuda-mode/lectures?tab=readme-ov-file#lecture-28-liger-kernel)
- [2024/8/23] Official release: check out our [X post](https://x.com/hsu_byron/status/1827072737673982056)

**Liger Kernel** is a collection of Triton kernels designed specifically for LLM training. It can increase multi-GPU **training throughput by 20%** and reduce **memory usage by 60%**. We have implemented **Hugging Face Compatible** `RMSNorm`, `RoPE`, `SwiGLU`, `CrossEntropy`, `FusedLinearCrossEntropy`, and more to come. The kernels work out of the box with [Flash Attention](https://github.com/Dao-AILab/flash-attention), [PyTorch FSDP](https://pytorch.org/tutorials/intermediate/FSDP_tutorial.html), and [Microsoft DeepSpeed](https://github.com/microsoft/DeepSpeed). We welcome contributions from the community to gather the best kernels for LLM training.
We've also added optimized Post-Training kernels that deliver **up to 80% memory savings** for alignment and distillation tasks. We support losses like DPO, CPO, ORPO, SimPO, KTO, JSD, and many more. Check out [how we optimize the memory](https://x.com/hsu_byron/status/1866577403918917655).
You can view the documentation site for additional installation, usage examples, and API references: https://linkedin.github.io/Liger-Kernel/
## Supercharge Your Model with Liger Kernel

With one line of code, Liger Kernel can increase throughput by more than 20% and reduce memory usage by 60%, thereby enabling longer context lengths, larger batch sizes, and massive vocabularies.
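To see why context length, batch size, and vocabulary size all press on memory at once, it helps to size the logits tensor that a language-model head produces. The sketch below is plain arithmetic, not a measurement: it assumes a batch of 8 sequences at a 4K context in `bf16` (2 bytes per element) and a 128,256-entry vocabulary (the LLaMA 3 vocabulary size). The helper name `logits_bytes` is made up for this illustration.

```python
def logits_bytes(batch_size, seq_len, vocab_size, bytes_per_element):
    """Size in bytes of a dense [batch, seq, vocab] logits tensor."""
    return batch_size * seq_len * vocab_size * bytes_per_element

GIB = 1024 ** 3

# Assumed workload: batch 8, 4K context, LLaMA 3 vocabulary, bf16 elements.
size = logits_bytes(batch_size=8, seq_len=4096, vocab_size=128_256, bytes_per_element=2)
print(f"{size / GIB:.2f} GiB")  # 7.83 GiB

# Doubling the context length doubles the logits tensor.
print(logits_bytes(8, 8192, 128_256, 2) // size)  # 2
```

Under these assumptions the logits alone take several GiB before counting their gradient, the weights, or the optimizer state, which is why not materializing them in full leaves room for longer sequences or larger batches.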
| Speed Up | Memory Reduction |
|--------------------------|-------------------------|
| [Figure: speed-up plot] | [Figure: memory-reduction plot] |

> **Note:**
> - Benchmark conditions: LLaMA 3-8B, Batch Size = 8, Data Type = `bf16`, Optimizer = AdamW, Gradient Checkpointing = True, Distributed Strategy = FSDP1 on 8 A100s.
> - Hugging Face models start to OOM at a 4K context length, whereas Hugging Face + Liger Kernel scales up to 16K.

## Optimize Post Training with Liger Kernel
We provide optimized post-training kernels such as DPO, ORPO, SimPO, and more, which can reduce memory usage by up to 80%. You can use them as Python modules.
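```python
from liger_kernel.chunked_loss import LigerFusedLinearORPOLoss

orpo_loss = LigerFusedLinearORPOLoss()
# lm_head, x, and target are placeholders for your model's output head,
# its input hidden states, and the target token ids.
y = orpo_loss(lm_head.weight, x, target)
```

## Examples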
| **Use Case** | **Description** |
|------------------------------------------------|---------------------------------------------------------------------------------------------------|
| [**Hugging Face Trainer**](https://github.com/linkedin/Liger-Kernel/tree/main/examples/huggingface) | Train LLaMA 3-8B ~20% faster with over 40% memory reduction on Alpaca dataset using 4 A100s with FSDP |
| [**Lightning Trainer**](https://github.com/linkedin/Liger-Kernel/tree/main/examples/lightning) | Increase throughput by 15% and reduce memory usage by 40% with LLaMA3-8B on MMLU dataset using 8 A100s with DeepSpeed ZeRO3 |
| [**Medusa Multi-head LLM (Retraining Phase)**](https://github.com/linkedin/Liger-Kernel/tree/main/examples/medusa) | Reduce memory usage by 80% with 5 LM heads and improve throughput by 40% using 8 A100s with FSDP |
| [**Vision-Language Model SFT**](https://github.com/linkedin/Liger-Kernel/tree/main/examples/huggingface/run_qwen2_vl.sh) | Finetune Qwen2-VL on image-text data using 4 A100s with FSDP |
| [**Liger ORPO Trainer**](https://github.com/linkedin/Liger-Kernel/blob/main/examples/alignment/run_orpo.py) | Align Llama 3.2 using Liger ORPO Trainer with FSDP with 50% memory reduction |

## Key Features
- **Ease of use:** Simply patch your Hugging Face model with one line of code, or compose your own model using our Liger Kernel modules.
- **Time and memory efficient:** In the same spirit as Flash-Attn, but for layers like **RMSNorm**, **RoPE**, **SwiGLU**, and **CrossEntropy**! Increases multi-GPU training throughput by 20% and reduces memory usage by 60% with **kernel fusion**, **in-place replacement**, and **chunking** techniques.
- **Exact:** Computation is exact—no approximations! Both forward and backward passes are implemented with rigorous unit tests and undergo convergence testing against training runs without Liger Kernel to ensure accuracy.
- **Lightweight:** Liger Kernel has minimal dependencies, requiring only Torch and Triton—no extra libraries needed! Say goodbye to dependency headaches!
- **Multi-GPU supported:** Compatible with multi-GPU setups (PyTorch FSDP, DeepSpeed, DDP, etc.).
- **Trainer Framework Integration**: [Axolotl](https://github.com/axolotl-ai-cloud/axolotl), [LLaMa-Factory](https://github.com/hiyouga/LLaMA-Factory), [SFTTrainer](https://github.com/huggingface/trl/releases/tag/v0.10.1), [Hugging Face Trainer](https://github.com/huggingface/transformers/pull/32860), [SWIFT](https://github.com/modelscope/ms-swift), [oumi](https://github.com/oumi-ai/oumi/tree/main)

## Installation
### Dependencies
#### CUDA
- `torch >= 2.1.2`
- `triton >= 2.3.0`

#### ROCm
- `torch >= 2.5.0` Install according to the instructions on the official PyTorch website.
- `triton >= 3.0.0` Install from PyPI (e.g. `pip install triton==3.0.0`).

```bash
# Need to pass the url when installing
pip install -e ".[dev]" --extra-index-url https://download.pytorch.org/whl/nightly/rocm6.2
```

### Optional Dependencies
- `transformers >= 4.x`: Required if you plan to use the transformers model patching APIs. The specific model you are working with will dictate the minimum version of transformers.
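The minimum versions listed above can be checked against the current environment before installing. The following is a minimal sketch using only the standard library; the names `MINIMUMS`, `parse_version`, `at_least`, and `check_minimums` are hypothetical helpers for this illustration and are not part of Liger Kernel. It assumes the CUDA minimums and that the packages are installed under the distribution names `torch` and `triton`.

```python
import re
from importlib import metadata

# Minimums for the CUDA path, copied from the dependency list above.
MINIMUMS = {"torch": "2.1.2", "triton": "2.3.0"}

def parse_version(text):
    """Leading numeric release segment as a tuple: '2.5.0+rocm6.2' -> (2, 5, 0)."""
    match = re.match(r"\d+(?:\.\d+)*", text)
    return tuple(int(part) for part in match.group().split(".")) if match else ()

def at_least(found, minimum):
    """True if `found` >= `minimum`, padding with zeros so 3.0 equals 3.0.0."""
    a, b = parse_version(found), parse_version(minimum)
    width = max(len(a), len(b))
    a += (0,) * (width - len(a))
    b += (0,) * (width - len(b))
    return a >= b

def check_minimums(minimums, lookup=metadata.version):
    """Map each package name to 'ok', 'missing', or a 'too old' message."""
    report = {}
    for name, minimum in minimums.items():
        try:
            found = lookup(name)
        except metadata.PackageNotFoundError:
            report[name] = "missing"
            continue
        report[name] = "ok" if at_least(found, minimum) else f"too old ({found} < {minimum})"
    return report

for name, status in check_minimums(MINIMUMS).items():
    print(f"{name}: {status}")
```

For a ROCm setup, the same check would use the ROCm minimums instead (`torch` 2.5.0 and `triton` 3.0.0).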
> **Note:**
> Our kernels inherit the full spectrum of hardware compatibility offered by [Triton](https://github.com/triton-lang/triton).

To install the stable version:
```bash
pip install liger-kernel
```

To install the nightly version:
```bash
pip install liger-kernel-nightly
```

To install from source:
```bash
git clone https://github.com/linkedin/Liger-Kernel.git
cd Liger-Kernel

# Install Default Dependencies
# Setup.py will detect whether you are using AMD or NVIDIA
pip install -e .

# Setup Development Dependencies
pip install -e ".[dev]"
```

## Getting Started
There are a couple of ways to apply Liger kernels, depending on the level of customization required.
### 1. Use AutoLigerKernelForCausalLM
Using the `AutoLigerKernelForCausalLM` is the simplest approach, as you don't have to import a model-specific patching API. If the model type is supported, the modeling code will be automatically patched using the default settings.
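One way to picture "patched automatically if the model type is supported" is a lookup table from model type to patch function. The sketch below is an illustration of that dispatch pattern only; `PATCHERS`, `auto_patch`, and the two stub functions are hypothetical and do not reflect Liger Kernel's internals. It also assumes that an unsupported model type is simply left unpatched.

```python
applied = []  # records which stub patches ran, for demonstration only

def patch_llama():
    applied.append("llama")

def patch_mistral():
    applied.append("mistral")

# Hypothetical registry: model type -> function that patches that model family.
PATCHERS = {"llama": patch_llama, "mistral": patch_mistral}

def auto_patch(model_type):
    """Run the registered patch for a supported type; report whether one ran."""
    patcher = PATCHERS.get(model_type)
    if patcher is None:
        return False  # assumed behavior: unsupported types load unmodified
    patcher()
    return True

print(auto_patch("llama"))    # True
print(auto_patch("unknown"))  # False
```

In Liger Kernel, the corresponding entry point is used as follows: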
```python
from liger_kernel.transformers import AutoLigerKernelForCausalLM

# This AutoModel wrapper class automatically monkey-patches the
# model with the optimized Liger kernels if the model is supported.
model = AutoLigerKernelForCausalLM.from_pretrained("path/to/some/model")
```

### 2. Apply Model-Specific Patching APIs
Using the [patching APIs](#patching), you can swap Hugging Face models with optimized Liger Kernels.
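As a rough illustration of what "swapping" means here, the sketch below monkey-patches a class on a stand-in module and shows that only models built after the patch pick up the replacement. Everything in it (`modeling`, `TinyModel`, `ReferenceRMSNorm`, `OptimizedRMSNorm`, `apply_patch`) is hypothetical and written in plain Python; it is not how Liger Kernel is implemented.

```python
import math
import types

class ReferenceRMSNorm:
    """Plain RMSNorm: x / sqrt(mean(x^2) + eps)."""
    name = "reference"

    def __call__(self, xs, eps=1e-6):
        rms = math.sqrt(sum(x * x for x in xs) / len(xs) + eps)
        return [x / rms for x in xs]

class OptimizedRMSNorm(ReferenceRMSNorm):
    """Stand-in for an optimized kernel: same math, different implementation."""
    name = "optimized"

# Hypothetical stand-in for a library's modeling module.
modeling = types.SimpleNamespace(RMSNorm=ReferenceRMSNorm)

class TinyModel:
    def __init__(self):
        # The class is looked up on the module when the model is constructed.
        self.norm = modeling.RMSNorm()

def apply_patch():
    """Replace the class the model constructor will find (the monkey patch)."""
    modeling.RMSNorm = OptimizedRMSNorm

before = TinyModel()
apply_patch()
after = TinyModel()
print(before.norm.name, after.norm.name)  # reference optimized
```

This is why the patch call comes before the model is instantiated in the Liger Kernel example that follows: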
```python
import transformers
from liger_kernel.transformers import apply_liger_kernel_to_llama

# 1a. Adding this line automatically monkey-patches the model with the optimized Liger kernels
apply_liger_kernel_to_llama()

# 1b. You could alternatively specify exactly which kernels are applied
apply_liger_kernel_to_llama(
    rope=True,
    swiglu=True,
    cross_entropy=True,
    fused_linear_cross_entropy=False,
    rms_norm=False,
)

# 2. Instantiate patched model
model = transformers.AutoModelForCausalLM.from_pretrained("path/to/llama/model")
```

### 3. Compose Your Own Model
You can take individual [kernels](https://github.com/linkedin/Liger-Kernel?tab=readme-ov-file#model-kernels) to compose your models.
```python
from liger_kernel.transformers import LigerFusedLinearCrossEntropyLoss
import torch.nn as nn
import torch

model = nn.Linear(128, 256).cuda()

# fuses linear + cross entropy layers together and performs chunk-by-chunk computation to reduce memory
loss_fn = LigerFusedLinearCrossEntropyLoss()

input = torch.randn(4, 128, requires_grad=True, device="cuda")
target = torch.randint(256, (4, ), device="cuda")

loss = loss_fn(model.weight, input, target)
loss.backward()
```

## High-level APIs
### AutoModel
| **AutoModel Variant** | **API** |
|-----------|---------|
| AutoModelForCausalLM | `liger_kernel.transformers.AutoLigerKernelForCausalLM` |

### Patching
| **Model** | **API** | **Supported Operations** |
|-------------|--------------------------------------------------------------|-------------------------------------------------------------------------|
| Llama4 (Text) & (Multimodal) | `liger_kernel.transformers.apply_liger_kernel_to_llama4` | RMSNorm, LayerNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| LLaMA 2 & 3 | `liger_kernel.transformers.apply_liger_kernel_to_llama` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| LLaMA 3.2-Vision | `liger_kernel.transformers.apply_liger_kernel_to_mllama` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Mistral | `liger_kernel.transformers.apply_liger_kernel_to_mistral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Mixtral | `liger_kernel.transformers.apply_liger_kernel_to_mixtral` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Gemma1 | `liger_kernel.transformers.apply_liger_kernel_to_gemma` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Gemma2 | `liger_kernel.transformers.apply_liger_kernel_to_gemma2` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Gemma3 (Text) | `liger_kernel.transformers.apply_liger_kernel_to_gemma3_text` | RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Gemma3 (Multimodal) | `liger_kernel.transformers.apply_liger_kernel_to_gemma3` | LayerNorm, RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Paligemma, Paligemma2, & Paligemma2 Mix | `liger_kernel.transformers.apply_liger_kernel_to_paligemma` | LayerNorm, RoPE, RMSNorm, GeGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Qwen2, Qwen2.5, & QwQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Qwen2-VL, & QVQ | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_vl` | RMSNorm, LayerNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Qwen2.5-VL | `liger_kernel.transformers.apply_liger_kernel_to_qwen2_5_vl` | RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Qwen3 | `liger_kernel.transformers.apply_liger_kernel_to_qwen3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Qwen3 MoE | `liger_kernel.transformers.apply_liger_kernel_to_qwen3_moe` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Phi3 & Phi3.5 | `liger_kernel.transformers.apply_liger_kernel_to_phi3` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| Granite 3.0 & 3.1 | `liger_kernel.transformers.apply_liger_kernel_to_granite` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss |
| OLMo2 | `liger_kernel.transformers.apply_liger_kernel_to_olmo2` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |
| GLM-4 | `liger_kernel.transformers.apply_liger_kernel_to_glm4` | RoPE, RMSNorm, SwiGLU, CrossEntropyLoss, FusedLinearCrossEntropy |

## Low-level APIs
- `Fused Linear` kernels combine linear layers with losses, reducing memory usage by up to 80% - ideal for HBM-constrained workloads.
- Other kernels use fusion and in-place techniques for memory and performance optimization.

### Model Kernels
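As a rough picture of the chunking idea behind the `Fused Linear` kernels, the sketch below computes a linear projection followed by cross entropy in two ways: once materializing all logits, and once a chunk of rows at a time. It is a plain-Python, forward-only illustration under simplifying assumptions (no gradients, no Triton, lists instead of tensors); all function names are hypothetical and this is not Liger Kernel's implementation.

```python
import math

def linear(weight, x):
    """Logits for one row: weight is [vocab][hidden], x is [hidden]."""
    return [sum(w * v for w, v in zip(row, x)) for row in weight]

def cross_entropy(logits, target):
    """Numerically stable -log softmax(logits)[target]."""
    m = max(logits)
    lse = m + math.log(sum(math.exp(z - m) for z in logits))
    return lse - logits[target]

def full_loss(weight, inputs, targets):
    # Holds every row of logits at once: memory grows with rows * vocab.
    all_logits = [linear(weight, x) for x in inputs]
    return sum(cross_entropy(l, t) for l, t in zip(all_logits, targets)) / len(inputs)

def chunked_loss(weight, inputs, targets, chunk_size):
    # Holds only one chunk of logits at a time: memory grows with chunk_size * vocab.
    total = 0.0
    for start in range(0, len(inputs), chunk_size):
        chunk_logits = [linear(weight, x) for x in inputs[start:start + chunk_size]]
        chunk_targets = targets[start:start + chunk_size]
        total += sum(cross_entropy(l, t) for l, t in zip(chunk_logits, chunk_targets))
    return total / len(inputs)

weight = [[0.1, -0.2], [0.0, 0.3], [0.5, 0.1]]  # vocab = 3, hidden = 2
inputs = [[1.0, 2.0], [0.5, -1.0], [-0.3, 0.8], [2.0, 0.1], [0.0, 0.0]]
targets = [0, 2, 1, 2, 0]

same = math.isclose(full_loss(weight, inputs, targets),
                    chunked_loss(weight, inputs, targets, chunk_size=2))
print(same)  # True
```

The two results agree because the mean loss is a sum over rows, so it can be accumulated chunk by chunk without ever storing the full logits matrix.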
| **Kernel** | **API** |
|---------------------------------|-------------------------------------------------------------|
| RMSNorm | `liger_kernel.transformers.LigerRMSNorm` |
| LayerNorm | `liger_kernel.transformers.LigerLayerNorm` |
| RoPE | `liger_kernel.transformers.liger_rotary_pos_emb` |
| SwiGLU | `liger_kernel.transformers.LigerSwiGLUMLP` |
| GeGLU | `liger_kernel.transformers.LigerGEGLUMLP` |
| CrossEntropy | `liger_kernel.transformers.LigerCrossEntropyLoss` |
| Fused Linear CrossEntropy | `liger_kernel.transformers.LigerFusedLinearCrossEntropyLoss`|
| Multi Token Attention | `liger_kernel.transformers.LigerMultiTokenAttention` |
| Softmax | `liger_kernel.transformers.LigerSoftmax` |
| Sparsemax | `liger_kernel.transformers.LigerSparsemax` |

### Alignment Kernels
| **Kernel** | **API** |
|---------------------------------|-------------------------------------------------------------|
| Fused Linear CPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearCPOLoss` |
| Fused Linear DPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearDPOLoss` |
| Fused Linear ORPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearORPOLoss` |
| Fused Linear SimPO Loss | `liger_kernel.chunked_loss.LigerFusedLinearSimPOLoss` |
| Fused Linear KTO Loss | `liger_kernel.chunked_loss.LigerFusedLinearKTOLoss` |

### Distillation Kernels
| **Kernel** | **API** |
|---------------------------------|-------------------------------------------------------------|
| KLDivergence | `liger_kernel.transformers.LigerKLDIVLoss` |
| JSD | `liger_kernel.transformers.LigerJSD` |
| Fused Linear JSD | `liger_kernel.transformers.LigerFusedLinearJSD` |
| TVD | `liger_kernel.transformers.LigerTVDLoss` |

### Experimental Kernels
| **Kernel** | **API** |
|---------------------------------|-------------------------------------------------------------|
| Embedding | `liger_kernel.transformers.experimental.LigerEmbedding` |
| Matmul int2xint8 | `liger_kernel.transformers.experimental.matmul` |

## Contributing, Acknowledgements, and License
- [Contributing Guidelines](https://github.com/linkedin/Liger-Kernel/blob/main/docs/contributing.md)
- [Acknowledgements](https://github.com/linkedin/Liger-Kernel/blob/main/docs/acknowledgement.md)
- [License Information](https://github.com/linkedin/Liger-Kernel/blob/main/docs/license.md)

## Sponsorship and Collaboration
- [Glows.ai](https://platform.glows.ai/): Sponsoring NVIDIA GPUs for our open source developers.
- [AMD](https://www.amd.com/en.html): Providing AMD GPUs for our AMD CI.
- [Intel](https://www.intel.com/): Providing Intel GPUs for our Intel CI.
- [Modal](https://modal.com/): Free 3000 credits from GPU MODE IRL for our NVIDIA CI.
- [EmbeddedLLM](https://embeddedllm.com/): Making Liger Kernel run fast and stable on AMD.
- [HuggingFace](https://huggingface.co/): Integrating Liger Kernel into Hugging Face Transformers and TRL.
- [Lightning AI](https://lightning.ai/): Integrating Liger Kernel into Lightning Thunder.
- [Axolotl](https://axolotl.ai/): Integrating Liger Kernel into Axolotl.
- [Llama-Factory](https://github.com/hiyouga/LLaMA-Factory): Integrating Liger Kernel into Llama-Factory.

## Contact
- For issues, create a GitHub ticket in this repository
- For open discussion, join [our discord channel on GPUMode](https://discord.com/channels/1189498204333543425/1275130785933951039)
- For formal collaboration, send an email to Yanning Chen (yannchen@linkedin.com) and Zhipeng Wang (zhipwang@linkedin.com)

## Cite this work
Biblatex entry:
```bib
@inproceedings{
hsu2025ligerkernel,
title={Liger-Kernel: Efficient Triton Kernels for {LLM} Training},
author={Pin-Lun Hsu and Yun Dai and Vignesh Kothapalli and Qingquan Song and Shao Tang and Siyu Zhu and Steven Shimizu and Shivam Sahni and Haowen Ning and Yanning Chen and Zhipeng Wang},
booktitle={Championing Open-source DEvelopment in ML Workshop @ ICML25},
year={2025},
url={https://openreview.net/forum?id=36SjAIT42G}
}
```