{"id":27344449,"url":"https://github.com/JT-Ushio/MHA2MLA","last_synced_at":"2025-04-12T17:06:23.519Z","repository":{"id":278733367,"uuid":"902416128","full_name":"JT-Ushio/MHA2MLA","owner":"JT-Ushio","description":"Towards Economical Inference: Enabling DeepSeek's Multi-Head Latent Attention in Any Transformer-based LLMs","archived":false,"fork":false,"pushed_at":"2025-04-11T04:50:00.000Z","size":88133,"stargazers_count":158,"open_issues_count":6,"forks_count":18,"subscribers_count":1,"default_branch":"main","last_synced_at":"2025-04-11T05:36:49.370Z","etag":null,"topics":["economical-key-value-cache","efficient-attention-architectures"],"latest_commit_sha":null,"homepage":"https://arxiv.org/abs/2502.14837","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"apache-2.0","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/JT-Ushio.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null,"zenodo":null}},"created_at":"2024-12-12T14:23:27.000Z","updated_at":"2025-04-09T05:41:47.000Z","dependencies_parsed_at":null,"dependency_job_id":"9a9f5e42-29db-445b-993f-08621518a1af","html_url":"https://github.com/JT-Ushio/MHA2MLA","commit_stats":null,"previous_names":["jt-ushio/mha2mla"],"tags_count":0,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/JT-Ushio%2FMHA2MLA","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/JT-Ushio%2FMHA2MLA/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/JT-Ushio%2FMHA2MLA/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/
JT-Ushio%2FMHA2MLA/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/JT-Ushio","download_url":"https://codeload.github.com/JT-Ushio/MHA2MLA/tar.gz/refs/heads/main","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":248602310,"owners_count":21131615,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["economical-key-value-cache","efficient-attention-architectures"],"created_at":"2025-04-12T17:02:15.942Z","updated_at":"2025-04-12T17:06:23.505Z","avatar_url":"https://github.com/JT-Ushio.png","language":"Python","funding_links":[],"categories":["A01_文本生成_文本对话"],"sub_categories":["大语言对话模型及数据"],"readme":"# MHA2MLA\n\nThis repo contains the code for the paper [\"Towards Economical Inference: Enabling DeepSeek's Multi-Head Latent Attention in Any Transformer-based LLMs\"](https://arxiv.org/abs/2502.14837).\n\n![alt text](img/overview.png)\n\n## News\n\n- [2025.03.12] Released the inference code implemented using **PyTorch** (support for [FlashMLA](https://github.com/deepseek-ai/FlashMLA) inference requires additional development time). 
\n- [2025.03.04] The four [MLA checkpoints](https://huggingface.co/collections/fnlp/mha2mla-67c51287dfc6cd46127e1b92) ($d_{kv}$=8/16/32/128) derived from `SmolLM-135M/360M/1B7` are publicly available.\n- [2025.03.03] The four [MLA checkpoints](https://huggingface.co/collections/fnlp/mha2mla-67c51287dfc6cd46127e1b92) ($d_{kv}$=16/32/64/256) derived from `Llama-2-7B` are publicly available.\n- [2025.02.21] The paper of MHA2MLA is publicly available: https://arxiv.org/abs/2502.14837\n- [2025.02.19] Released the first version of the MHA2MLA code, providing usage code for Llama fine-tuning and evaluating.\n\n## TO-DO\n\n- [ ] ~~Provide the code for incorporating the projection matrix and inference.~~\n- [ ] Thanks to DeepSeek for open-sourcing the [FlashMLA](https://github.com/deepseek-ai/FlashMLA) inference framework. It’s theoretically possible to save more GPU memory usage using this framework. Let’s see how economical MHA2MLA + FlashMLA (+ KV quanto) can be!\n- [x] Release the code of MHA2MLA based on HuggingFace `Transformers`\n\n## Models\n\n- SmolLM: https://huggingface.co/blog/smollm\n- Llama-2-7b-hf: https://huggingface.co/meta-llama/Llama-2-7b-hf\n\n## Datasets\n\nFirst download the datasets.\n\n- smollm-corpus(fineweb-edu-dedup, cosmopedia-v2, python-edu): https://huggingface.co/datasets/HuggingFaceTB/smollm-corpus\n- open-web-math: https://huggingface.co/datasets/open-web-math/open-web-math\n- stackoverflow: https://huggingface.co/datasets/bigcode/stackoverflow-clean\n\nSecondly, process the datasets according to https://github.com/huggingface/nanotron/blob/main/docs/nanoset.md.\n\n## Environment\n\nInstall pytorch and other packages.\n\n```sh\nconda create -n mla-ft python=3.11\npip install torch==2.4.0 torchvision==0.19.0\npip install -r requirements.txt\n```\n\n## MHA2MLA Fine-Tuning with huggingface transformers\n\n\u003e The research presented in our paper was conducted using [nanotron](https://github.com/huggingface/nanotron) framework. 
Since there are differences between `transformers` and `nanotron`, hyperparameter search might be necessary. For exact reproduction of the paper's results, we recommend using nanotron for fine-tuning; see [**Our README for MHA2MLA using nanotron**](./src/mha2mla_nt/README.md).\n\nFirst, prepare three configuration files:\n1. A general configuration file referencing [135M_4GPU.yaml](./configs_hf/rope/135M_4GPU.yaml)\n2. A partial-RoPE configuration file referencing [rope_v4_topk4.yaml](./configs_hf/rope/rope_v4_topk4.yaml)\n3. An SVD configuration file referencing [svd_method7_rank8.yaml](./configs_hf/rope/svd_method7_rank8.yaml)\n\nThe available strategies for each method are listed below:\n\n| Partial-RoPE version | Strategy                       |\n| :------------------: | ------------------------------ |\n|          0           | full-RoPE                      |\n|          1           | $\\mathcal{S}_{\\text{high}}$    |\n|          2           | $\\mathcal{S}_{\\text{uniform}}$ |\n|          4           | $\\mathcal{S}_{\\text{2-norm}}$  |\n|          5           | $\\mathcal{S}_{\\text{low}}$   |\n\n| SVD version | Strategy          |\n| :---------: | ---------------- |\n|      2      | $SVD_{split}$ |\n|      7      | $SVD_{joint}$ |\n\nThen, use the following command for MLA fine-tuning:\n```sh\ntorchrun --nproc_per_node 4 \\\n    ../src/mha2mla/run_train.py \\\n    --config_file ../configs_hf/rope/135M_4GPU.yaml \\\n    --partial_rope_config ../configs_hf/rope/rope_v4_topk4.yaml \\\n    --svd_config ../configs_hf/rope/svd_method7_rank8.yaml\n```\n\n\n\u003e If you want to use the partial-RoPE version 4, you should get the `qk_tensor` first.\n\u003e Using the following command, you can get the `qk_tensor`:\n\u003e\n\u003e ```sh\n\u003e torchrun --nproc_per_node 1 \\\n\u003e     ../src/mha2mla/2_norm.py \\\n\u003e     --config_file ../configs_hf/rope/135M_4GPU.yaml \\\n\u003e     --output_dir ./qk_tensor_hf_test.pth \\\n\u003e     --sample_size 
1024\n\u003e ```\n\n## Lighteval Evaluation\n\nFor the MLA evaluation, you can use the following command:\n\n```sh\naccelerate launch --multi_gpu --num_processes=4 \\\n    ../src/mha2mla/eval.py --is_mla \\\n    accelerate \\\n    --model_args \"pretrained=${model_name_or_path},revision=main,dtype=bfloat16,max_length=2048\" \\\n    --override_batch_size 48 \\\n    --custom_tasks \"../src/mha2mla/tasks.py\" \\\n    --tasks \"../src/mha2mla/smollm1_base.txt\" \\\n    --output_dir \"../eval_results/\"\n```\n\u003e If you want to evaluate the `partial_rope` ckpt without `low rank approx`, you should change `--is_mla` to `--is_partial_rope`.\n\n## LongBench Evaluation\n\nFor the baseline evaluation, you can use the following command:\n\n```sh\ntorchrun --nproc_per_node=4 \\\n    ../src/mha2mla/longbench.py \\\n    --model_path ${model_name_or_path} \\\n    --tokenizer_path ${model_name_or_path} \\\n    --longbench True \\\n    --lb_max_tokens 2048 \\\n    --lb_batch_size 16 \\\n    --output_dir /longbench/bf16 \\\n    --dtype \"bfloat16\"\n```\n\nFor the MLA model, you should add the parameter `--is_mla` to the command.\n\nIf you want to use the quantized KV cache, you can use the following command:\n\n```sh\ntorchrun --nproc_per_node=4 \\\n    ../src/mha2mla/longbench.py \\\n    --model_path ${model_name_or_path} \\\n    --tokenizer_path ${model_name_or_path} \\\n    --longbench True \\\n    --lb_max_tokens 2048 \\\n    --lb_batch_size 16 \\\n    --output_dir /longbench/${model_name_or_path}_hqq_int4 \\\n    --dtype \"bfloat16\" \\\n    --cache_implementation \"quantized\" \\\n    --backend \"HQQ\" \\\n    --nbits 4 \\\n    --residual_length 128\n```\n\n## Inference\n\n- Step 1: Download the [**monkey patch file**](src/mha2mla/monkey_patch.py).\n```shell\nwget https://raw.githubusercontent.com/JT-Ushio/MHA2MLA/refs/heads/main/src/mha2mla/monkey_patch.py\n```\n\n- Step 2 (optional): For MHA2MLA models using the Partial-RoPE 2-norm method, download the [**qk_2-norm 
file**](./utils/). \nTake `qk_tensor_1.7B.pth` as an example:\n```shell\nwget https://github.com/JT-Ushio/MHA2MLA/raw/refs/heads/main/utils/qk_tensor_1.7B.pth\n```\n\n- Step 3: Download the [MHA2MLA models](https://huggingface.co/collections/fnlp/mha2mla-67c51287dfc6cd46127e1b92) and run inference. \nTake `fnlp/SmolLM-1B7-MLA-d_kv_16` as an example:\n\n```python\nimport torch\nfrom transformers import AutoConfig, AutoTokenizer, LlamaForCausalLM\nfrom monkey_patch import infer_monkey_patch\n\nmodel_name = \"fnlp/SmolLM-1B7-MLA-d_kv_16\"\n\n# Monkey Patch: MHA -\u003e MLA\nconfig = AutoConfig.from_pretrained(model_name)\nif \"RoPE\" in config:\n    config.RoPE[\"qk_tensor_path\"] = \"qk_tensor_1.7B.pth\"  # Configuration for Specific Models\n    infer_monkey_patch(config.RoPE)\n\ntokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)\nmodel = LlamaForCausalLM.from_pretrained(model_name, config=config, torch_dtype=torch.bfloat16).cuda()\n\n# Generate\ntext = \"Which American-born Sinclair won the Nobel Prize for Literature in 1930?\"\ninputs = tokenizer(text, return_tensors=\"pt\").to(model.device)\ngeneration_kwargs = {\"do_sample\": False, \"use_cache\": True, \"max_new_tokens\": 128}\noutput = model.generate(**inputs, **generation_kwargs)\n\nprint(tokenizer.decode(output[0], skip_special_tokens=True))\n# - Sinclair Lewis\n```\n\n## Citation\n```\n@misc{ji2025economicalinferenceenablingdeepseeks,\n      title={Towards Economical Inference: Enabling DeepSeek's Multi-Head Latent Attention in Any Transformer-based LLMs}, \n      author={Tao Ji and Bin Guo and Yuanbin Wu and Qipeng Guo and Lixing Shen and Zhan Chen and Xipeng Qiu and Qi Zhang and Tao Gui},\n      year={2025},\n      eprint={2502.14837},\n      archivePrefix={arXiv},\n      primaryClass={cs.CL},\n      url={https://arxiv.org/abs/2502.14837}, 
\n}\n```\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2FJT-Ushio%2FMHA2MLA","html_url":"https://awesome.ecosyste.ms/projects/github.com%2FJT-Ushio%2FMHA2MLA","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2FJT-Ushio%2FMHA2MLA/lists"}