https://github.com/JT-Ushio/MHA2MLA
Towards Economical Inference: Enabling DeepSeek's Multi-Head Latent Attention in Any Transformer-based LLMs
- Host: GitHub
- URL: https://github.com/JT-Ushio/MHA2MLA
- Owner: JT-Ushio
- License: apache-2.0
- Created: 2024-12-12T14:23:27.000Z (5 months ago)
- Default Branch: main
- Last Pushed: 2025-04-11T04:50:00.000Z (24 days ago)
- Last Synced: 2025-04-11T05:36:49.370Z (24 days ago)
- Topics: economical-key-value-cache, efficient-attention-architectures
- Language: Python
- Homepage: https://arxiv.org/abs/2502.14837
- Size: 84.1 MB
- Stars: 158
- Watchers: 1
- Forks: 18
- Open Issues: 6
Metadata Files:
- Readme: README.md
- License: LICENSE
Awesome Lists containing this project
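- StarryDivineSky - JT-Ushio/MHA2MLA - Head Latent Attention (MLA), enabling more economical inference. The project lowers computational cost, especially for long-sequence inference, by replacing standard Multi-Head Attention (MHA) with MLA. The core idea of MLA is to use low-rank matrices to approximate the attention matrix, reducing computation and memory usage. The project provides a detailed theoretical explanation and code implementation, making it easy for users to integrate MLA into their own models. It supports the PyTorch framework and provides example code and experimental results showing MLA's performance gains on different models. The project's goal is to let more developers take advantage of MLA to build more efficient LLM applications. Specifically, it compresses attention information by learning a low-dimensional latent space, thereby reducing computational complexity. The project also provides evaluation tools for comparing MHA and MLA in inference speed and accuracy. Overall, MHA2MLA offers a practical way to substantially improve LLM inference efficiency without significantly degrading model performance. (A01_Text Generation_Text Dialogue / Large Language Dialogue Models and Data)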
README
# MHA2MLA
This repo contains the code for the paper ["Towards Economical Inference: Enabling DeepSeek's Multi-Head Latent Attention in Any Transformer-based LLMs"](https://arxiv.org/abs/2502.14837).

## News
- [2025.03.12] Released the inference code, implemented in **PyTorch** ([FlashMLA](https://github.com/deepseek-ai/FlashMLA) inference support will take additional development time).
- [2025.03.04] The four [MLA checkpoints](https://huggingface.co/collections/fnlp/mha2mla-67c51287dfc6cd46127e1b92) ($d_{kv}$=8/16/32/128) derived from `SmolLM-135M/360M/1B7` are publicly available.
- [2025.03.03] The four [MLA checkpoints](https://huggingface.co/collections/fnlp/mha2mla-67c51287dfc6cd46127e1b92) ($d_{kv}$=16/32/64/256) derived from `Llama-2-7B` are publicly available.
- [2025.02.21] The MHA2MLA paper is publicly available: https://arxiv.org/abs/2502.14837
- [2025.02.19] Released the first version of the MHA2MLA code, providing usage code for Llama fine-tuning and evaluation.

## TO-DO
- [ ] ~~Provide the code for incorporating the projection matrix and inference.~~
- [ ] Thanks to DeepSeek for open-sourcing the [FlashMLA](https://github.com/deepseek-ai/FlashMLA) inference framework. In theory, this framework can reduce GPU memory usage even further. Let’s see how economical MHA2MLA + FlashMLA (+ KV quanto) can be!
- [x] Release the code of MHA2MLA based on HuggingFace `Transformers`

## Models
- SmolLM: https://huggingface.co/blog/smollm
- Llama-2-7b-hf: https://huggingface.co/meta-llama/Llama-2-7b-hf

## Datasets
First, download the datasets:
- smollm-corpus (fineweb-edu-dedup, cosmopedia-v2, python-edu): https://huggingface.co/datasets/HuggingFaceTB/smollm-corpus
- open-web-math: https://huggingface.co/datasets/open-web-math/open-web-math
- stackoverflow: https://huggingface.co/datasets/bigcode/stackoverflow-clean

Then, process the datasets following https://github.com/huggingface/nanotron/blob/main/docs/nanoset.md.
## Environment
Install PyTorch and the other required packages:
```sh
conda create -n mla-ft python=3.11
conda activate mla-ft
pip install torch==2.4.0 torchvision==0.19.0
pip install -r requirements.txt
```

## MHA2MLA Fine-Tuning with Hugging Face Transformers
> The research in our paper was conducted with the [nanotron](https://github.com/huggingface/nanotron) framework. Because `transformers` and `nanotron` differ, some hyperparameter search may be necessary. To reproduce the paper's results exactly, we recommend fine-tuning with nanotron; see [**Our README for MHA2MLA using nanotron**](./src/mha2mla_nt/README.md).
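To make the two knobs chosen by the configuration files below more concrete, here is a minimal NumPy sketch based on our reading of the paper; it is illustrative only and is not the repo's implementation. Partial-RoPE keeps rotary encoding on only `r` frequency pairs per head, and the strategies differ in which pairs survive (highest, lowest, evenly spaced, or, for $\mathcal{S}_{\text{2-norm}}$, the pairs with the largest query/key 2-norm product). The SVD step then factors the key/value projections through a small latent, which is what an MLA layer caches: `svd_joint` shares one latent between K and V, while `svd_split` factors them separately. All helper names are hypothetical, and these are assumptions: the HF Llama `rotate_half` pair layout, the exact uniform spacing, averaging the 2-norms over tokens, the even latent split in `svd_split`, and plain weight SVD. As we understand it, the paper factorizes only the key dimensions that no longer carry RoPE, which this sketch ignores.

```python
import numpy as np

# Partial-RoPE sketch. Assumes the HF Llama "rotate_half" layout: frequency pair j
# rotates dims j and j + head_dim // 2, and pair 0 has the highest frequency.


def select_rope_pairs(strategy, head_dim, r, q=None, k=None):
    """Hypothetical helper: indices of the r frequency pairs of one head that keep RoPE."""
    n_pairs = head_dim // 2
    if strategy == "high":  # S_high: the r highest-frequency pairs
        return np.arange(r)
    if strategy == "low":  # S_low: the r lowest-frequency pairs
        return np.arange(n_pairs - r, n_pairs)
    if strategy == "uniform":  # S_uniform: r pairs at equal spacing (assumed spacing)
        return np.arange(0, n_pairs, n_pairs // r)[:r]
    if strategy == "2-norm":  # S_2-norm: top-r pairs by mean ||q_j|| * ||k_j|| over tokens
        q_norm = np.hypot(q[:, :n_pairs], q[:, n_pairs:]).mean(axis=0)  # q, k: (tokens, head_dim)
        k_norm = np.hypot(k[:, :n_pairs], k[:, n_pairs:]).mean(axis=0)
        return np.sort(np.argsort(q_norm * k_norm)[-r:])
    raise ValueError(f"unknown strategy: {strategy}")


def pair_dims(pairs, head_dim):
    """Head dimensions rotated by the given frequency pairs; all other dims drop RoPE."""
    return np.sort(np.concatenate([pairs, pairs + head_dim // 2]))


# Low-rank KV sketch: factor the key/value projections through a small latent.
def truncated_factor(w, rank):
    """Best rank-`rank` factorization w ~= a @ b in Frobenius norm (truncated SVD)."""
    u, s, vt = np.linalg.svd(w, full_matrices=False)
    return u[:, :rank] * s[:rank], vt[:rank]


def svd_joint(w_k, w_v, rank):
    """SVD_joint-style: one SVD of [w_k | w_v], so K and V share one cached latent."""
    w_down, w_up = truncated_factor(np.concatenate([w_k, w_v], axis=1), rank)
    d_k = w_k.shape[1]
    return w_down, w_up[:, :d_k], w_up[:, d_k:]


def svd_split(w_k, w_v, rank):
    """SVD_split-style: separate SVDs of w_k and w_v, each given half the latent (assumption)."""
    h = rank // 2
    a_k, b_k = truncated_factor(w_k, h)
    a_v, b_v = truncated_factor(w_v, rank - h)
    w_down = np.concatenate([a_k, a_v], axis=1)  # cached latent is [c_k | c_v]
    w_up_k = np.vstack([b_k, np.zeros((rank - h, w_k.shape[1]))])  # K reads only c_k
    w_up_v = np.vstack([np.zeros((h, w_v.shape[1])), b_v])  # V reads only c_v
    return w_down, w_up_k, w_up_v


print(select_rope_pairs("uniform", head_dim=64, r=4).tolist())  # [0, 8, 16, 24]

rng = np.random.default_rng(0)
d_model, d_head = 32, 16
w_k, w_v = rng.standard_normal((d_model, d_head)), rng.standard_normal((d_model, d_head))
x = rng.standard_normal((5, d_model))  # 5 tokens
for fn in (svd_joint, svd_split):
    w_down, w_up_k, w_up_v = fn(w_k, w_v, rank=8)
    c = x @ w_down  # what an MLA layer would cache per token: width 8 instead of 2 * 16
    k_err = np.linalg.norm(x @ w_k - c @ w_up_k)
    print(f"{fn.__name__}: latent {c.shape}, key reconstruction error {k_err:.2f}")
```

Because truncated SVD gives the best rank-`rank` approximation of the concatenated matrix, `svd_joint` never reconstructs `[w_k | w_v]` worse than `svd_split` at the same total latent width; the split variant trades that for independent K and V latents.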
First, prepare three configuration files:
1. A general configuration file, e.g. [135M_4GPU.yaml](./configs_hf/rope/135M_4GPU.yaml)
2. A partial-RoPE configuration file, e.g. [rope_v4_topk4.yaml](./configs_hf/rope/rope_v4_topk4.yaml)
3. An SVD configuration file, e.g. [svd_method7_rank8.yaml](./configs_hf/rope/svd_method7_rank8.yaml)

The available strategies for each method are listed below:

| Partial-RoPE version | Strategy |
| :------------------: | ------------------------------ |
| 0 | full-RoPE |
| 1 | $\mathcal{S}_{\text{high}}$ |
| 2 | $\mathcal{S}_{\text{uniform}}$ |
| 4 | $\mathcal{S}_{\text{2-norm}}$ |
| 5 | $\mathcal{S}_{\text{low}}$ |

| SVD version | Strategy |
| :---------: | ---------------- |
| 2 | $\text{SVD}_{\text{split}}$ |
| 7 | $\text{SVD}_{\text{joint}}$ |

Then, use the following command for MLA fine-tuning:
```sh
torchrun --nproc_per_node 4 \
../src/mha2mla/run_train.py \
--config_file ../configs_hf/rope/135M_4GPU.yaml \
--partial_rope_config ../configs_hf/rope/rope_v4_topk4.yaml \
--svd_config ../configs_hf/rope/svd_method7_rank8.yaml
```

> If you want to use partial-RoPE version 4, you must first compute the `qk_tensor`.
> You can compute it with the following command:
>
> ```sh
> torchrun --nproc_per_node 1 \
> ../src/mha2mla/2_norm.py \
> --config_file ../configs_hf/rope/135M_4GPU.yaml \
> --output_dir ./qk_tensor_hf_test.pth \
> --sample_size 1024
> ```

## Lighteval Evaluation
For the MLA evaluation, you can use the following command:
```sh
accelerate launch --multi_gpu --num_processes=4 \
../src/mha2mla/eval.py --is_mla \
accelerate \
--model_args "pretrained=${model_name_or_path},revision=main,dtype=bfloat16,max_length=2048" \
--override_batch_size 48 \
--custom_tasks "../src/mha2mla/tasks.py" \
--tasks "../src/mha2mla/smollm1_base.txt" \
--output_dir "../eval_results/"
```
> To evaluate a `partial_rope` checkpoint without the low-rank approximation, replace `--is_mla` with `--is_partial_rope`.

## LongBench Evaluation
For the baseline evaluation, you can use the following command:
```sh
torchrun --nproc_per_node=4 \
../src/mha2mla/longbench.py \
--model_path ${model_name_or_path} \
--tokenizer_path ${model_name_or_path} \
--longbench True \
--lb_max_tokens 2048 \
--lb_batch_size 16 \
--output_dir /longbench/bf16 \
--dtype "bfloat16"
```

For the MLA model, add the `--is_mla` parameter to the command.
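For a rough sense of why the MLA variant matters for long inputs, here is back-of-the-envelope KV cache arithmetic in Python. The Llama-2-7B attention shape (32 layers, 32 heads, head dimension 128) is standard; the MLA-side formula, a per-layer latent plus the key dimensions that keep RoPE, is our simplification, and `latent_dim=1024`, `rope_dims=256` are hypothetical numbers, not the accounting used in the paper or by any released checkpoint. The helper names are illustrative, and quantization metadata and framework overhead are ignored.

```python
def mha_elems_per_token(n_heads, head_dim):
    """MHA caches a full key and a full value vector for every head."""
    return 2 * n_heads * head_dim


def mla_elems_per_token(latent_dim, rope_dims):
    """Simplified MLA-style cache (assumption): one shared latent plus the key dims that keep RoPE."""
    return latent_dim + rope_dims


def kv_cache_bytes(n_tokens, n_layers, elems_per_token, bytes_per_elem=2):
    """Total cache size for one sequence; bf16 stores 2 bytes per element."""
    return n_tokens * n_layers * elems_per_token * bytes_per_elem


# Llama-2-7B attention shape: 32 layers, 32 heads, head_dim 128.
n_layers, n_heads, head_dim = 32, 32, 128
n_tokens = 2048  # same budget as --lb_max_tokens in the LongBench command
mha = kv_cache_bytes(n_tokens, n_layers, mha_elems_per_token(n_heads, head_dim))
# Hypothetical MLA sizes for illustration only, not a released checkpoint's configuration.
mla = kv_cache_bytes(n_tokens, n_layers, mla_elems_per_token(latent_dim=1024, rope_dims=256))
print(f"MHA: {mha / 2**20:.0f} MiB, MLA sketch: {mla / 2**20:.0f} MiB, saving {1 - mla / mha:.1%}")
# MHA: 1024 MiB, MLA sketch: 160 MiB, saving 84.4%
```

These figures are per sequence; with `--lb_batch_size 16`, the MHA cache alone grows to 16 GiB.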
To use a quantized KV cache, run the following command:
```sh
torchrun --nproc_per_node=4 \
../src/mha2mla/longbench.py \
--model_path ${model_name_or_path} \
--tokenizer_path ${model_name_or_path} \
--longbench True \
--lb_max_tokens 2048 \
--lb_batch_size 16 \
--output_dir /longbench/${model_name_or_path}_hqq_int4 \
--dtype "bfloat16" \
--cache_implementation "quantized" \
--backend "HQQ" \
--nbits 4 \
--residual_length 128
```

## Inference
- Step 1: Download the [**monkey patch file**](src/mha2mla/monkey_patch.py).
```shell
wget https://raw.githubusercontent.com/JT-Ushio/MHA2MLA/refs/heads/main/src/mha2mla/monkey_patch.py
```
- Step 2 (optional): For MHA2MLA models that use the Partial-RoPE 2-norm method, download the [**qk_2-norm file**](./utils/).
Take `qk_tensor_1.7B.pth` as an example:
```shell
wget https://github.com/JT-Ushio/MHA2MLA/raw/refs/heads/main/utils/qk_tensor_1.7B.pth
```
- Step 3: Download the [MHA2MLA models](https://huggingface.co/collections/fnlp/mha2mla-67c51287dfc6cd46127e1b92) and run inference.
Take `fnlp/SmolLM-1B7-MLA-d_kv_16` as an example:

```python
import torch
from transformers import AutoConfig, AutoTokenizer, LlamaForCausalLM
from monkey_patch import infer_monkey_patch

model_name = "fnlp/SmolLM-1B7-MLA-d_kv_16"

# Monkey Patch: MHA -> MLA
config = AutoConfig.from_pretrained(model_name)
if hasattr(config, "RoPE"):  # MHA2MLA configs carry a `RoPE` entry
    config.RoPE["qk_tensor_path"] = "qk_tensor_1.7B.pth"  # qk 2-norm file from Step 2 (model-specific)
    infer_monkey_patch(config.RoPE)

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
model = LlamaForCausalLM.from_pretrained(model_name, config=config, torch_dtype=torch.bfloat16).cuda()

# Generate
text = "Which American-born Sinclair won the Nobel Prize for Literature in 1930?"
inputs = tokenizer(text, return_tensors="pt").to(model.device)
generation_kwargs = {"do_sample": False, "use_cache": True, "max_new_tokens": 128}
output = model.generate(**inputs, **generation_kwargs)

print(tokenizer.decode(output[0], skip_special_tokens=True))
# - Sinclair Lewis
```

## Citation
```
@misc{ji2025economicalinferenceenablingdeepseeks,
title={Towards Economical Inference: Enabling DeepSeek's Multi-Head Latent Attention in Any Transformer-based LLMs},
author={Tao Ji and Bin Guo and Yuanbin Wu and Qipeng Guo and Lixing Shen and Zhan Chen and Xipeng Qiu and Qi Zhang and Tao Gui},
year={2025},
eprint={2502.14837},
archivePrefix={arXiv},
primaryClass={cs.CL},
url={https://arxiv.org/abs/2502.14837},
}
```