{"id":13752918,"url":"https://github.com/FlagOpen/FlagAttention","last_synced_at":"2025-05-09T20:34:36.376Z","repository":{"id":199819804,"uuid":"701597525","full_name":"FlagOpen/FlagAttention","owner":"FlagOpen","description":"A collection of memory efficient attention operators implemented in the Triton language.","archived":false,"fork":false,"pushed_at":"2024-06-05T09:41:11.000Z","size":998,"stargazers_count":265,"open_issues_count":6,"forks_count":18,"subscribers_count":6,"default_branch":"main","last_synced_at":"2025-04-28T13:02:50.469Z","etag":null,"topics":["attention","triton-lang"],"latest_commit_sha":null,"homepage":"","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"other","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/FlagOpen.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null}},"created_at":"2023-10-07T02:44:33.000Z","updated_at":"2025-04-27T06:14:58.000Z","dependencies_parsed_at":"2023-12-18T05:47:11.440Z","dependency_job_id":"b7259723-94c5-43a6-ae71-2cd4fbda0086","html_url":"https://github.com/FlagOpen/FlagAttention","commit_stats":{"total_commits":39,"total_committers":2,"mean_commits":19.5,"dds":"0.28205128205128205","last_synced_commit":"14f5f1a350ad667de7fe4dab7cc46a4b4a28fbfa"},"previous_names":["flagopen/flagattention"],"tags_count":2,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/FlagOpen%2FFlagAttention","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/FlagOpen%2FFlagAttention/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/FlagOpen%
2FFlagAttention/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/FlagOpen%2FFlagAttention/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/FlagOpen","download_url":"https://codeload.github.com/FlagOpen/FlagAttention/tar.gz/refs/heads/main","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":253321797,"owners_count":21890466,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["attention","triton-lang"],"created_at":"2024-08-03T09:01:12.619Z","updated_at":"2025-05-09T20:34:32.128Z","avatar_url":"https://github.com/FlagOpen.png","language":"Python","funding_links":[],"categories":["Transformer库与优化"],"sub_categories":[],"readme":"# FlagAttention\n\n\u003cp align=\"center\"\u003e\n    \u003cimg src=\"./assets/logo/horizontal-blue.png\" width = \"400\" alt=\"flag-attention\" \u003e\n\u003c/p\u003e\n\n[中文版](./README_cn.md)\n\nFlagAttention is a project for memory-efficient attention operators implemented in the [Triton language](https://github.com/openai/triton). Motivated by the need for non-standard attention operators in language modeling, it started as an extension of multi-head attention.\n\nLike [FlashAttention](https://arxiv.org/abs/2205.14135) and [FlashAttention v2](https://tridao.me/publications/flash2/flash2.pdf), it reduces both memory footprint and memory traffic. Because it is implemented in the Triton language, it is easier to understand and modify. 
The original CUDA implementation of FlashAttention ([flash-attention](https://github.com/Dao-AILab/flash-attention)) is a good example of how to design an algorithm that takes different levels of the memory hierarchy into account. Through tiling and re-computation, FlashAttention avoids materializing the attention scores, whose memory footprint grows quadratically with the sequence length. However, custom transformations of the attention scores are not possible with FlashAttention unless FlashAttention supports them out of the box.\nWhile extending FlashAttention requires proficiency in CUDA programming, FlagAttention, being implemented in the Triton language, is easier to modify.\n\nFlagAttention now offers two operators.\n\n1. **flash_attention**: FlashAttention implemented in the Triton language.\n2. **piecewise_attention**: currently used for NLPE (Non-Linearized Position Embedding) in both training and inference of the [Aquila-2-34B](https://github.com/FlagAI-Open/Aquila2) model.\n\nWhen further customization is required, FlagAttention serves as an example.\n\n## Changelog\n\n### v0.1\n\nAdd piecewise_attention \u0026 flash_attention.\n\n### v0.2\n\nOptimization of operators.\n\n1. Apply masks only when needed.\n2. Use a separate kernel to compute the gradient of q, avoiding atomic RMW (read-modify-write) operations on global memory.\n\n## Requirements\n\nFlagAttention requires PyTorch and Triton. To use the new features of Triton, a nightly release is recommended.\n\n```sh\n# install a nightly release of Triton\npip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly\n```\n\nFlagAttention requires NVIDIA Ampere GPUs (e.g. A100, RTX-3090, ...) and CUDA Toolkit 11.6 or above. Other GPUs may work but have not been tested yet.\n\n## Installation\n\nFlagAttention can be installed in either of the following ways.\n\n1. Editable Installation. Changes to the code in the local source tree take effect without re-installation.\n2. 
Build a distribution and then install it. Only the built package is installed.\n\n### Editable Installation\n\nEditable installation with pip.\n\n```sh\ngit clone https://github.com/FlagOpen/FlagAttention \u0026\u0026 cd FlagAttention\npip install -e .\n```\n\n### Build a Distribution \u0026 Install\n\nFollowing modern Python packaging conventions (PEP 517), FlagAttention is configured by [`pyproject.toml`](https://pip.pypa.io/en/stable/reference/build-system/pyproject-toml/), and no `setup.py` is provided. To build a distribution (either a source distribution or a binary distribution), the Python package `build` is recommended.\n\nFirst, install the `build` package via pip.\n\n```sh\npip install build\n```\n\nThen build the package.\n\n```sh\ngit clone https://github.com/FlagOpen/FlagAttention \u0026\u0026 cd FlagAttention\n# building in `no-isolation` mode requires installing the build requirements manually\npip install -U setuptools setuptools-scm\npython -m build --no-isolation\n```\n\nThe built package is written to `dist/`; install it from there.\n\n```sh\npip install dist/flag_attn-xxx.whl\n```\n\n## Usage\n\nFlagAttention provides customized attention operators. When an operator is equivalent to a torch function, it can be used as a drop-in replacement.\n\n## Run the Tests\n\nA recent version of `pytest` (\u003e=7.1.0) is required to run the tests in `tests/`. Operators in `FlagAttention` are tested against [reference implementations](src/flag_attn/testing) in PyTorch provided by `flag_attn.testing`, for both the forward and backward operators. For operators that support `float16` or `bfloat16` inputs, three different implementations are included for numerical accuracy testing.\n\n1. **Reference Implementation in PyTorch**: This implementation upcasts the inputs to `float32`, performs all computations in `float32`, and finally casts the outputs to `float16` or `bfloat16`.\n2. 
**Triton Implementation**: The Triton implementation uses `float16` or `bfloat16` for MMA (matrix multiply-accumulate) inputs and `float32` for MMA outputs and other computations.\n3. **PyTorch Implementation**: This implementation mirrors the computations in the reference implementation, except that it uses the same precision as the Triton implementation.\n\nThe numerical accuracy tests require that the maximum difference between the Triton implementation and the reference implementation be no greater than twice the maximum difference between the PyTorch implementation and the reference implementation.\n\n```sh\npytest .\n```\n\n## Run the Benchmark\n\nBenchmarks are included to measure the achieved `TFLOP/s`, which serves as a metric of operator speed. The FLOP count for an operator includes only the matmul operations. This FLOP count is then divided by the median runtime to obtain the achieved FLOP/s.\n\nThe benchmarks compare the Triton implementations with their PyTorch counterparts. When the input is so large that the PyTorch implementation runs out of memory, its FLOP/s is reported as zero.\n\n```sh\ncd benchmarks/\npython flash_benchmark.py\npython piecewise_benchmark.py\n```\n\n## Operators\n\n### flash_attention\n\nThe implementation of FlashAttention in the Triton language. The interface is:\n\n```python\nflash_attention(q, k, v, causal=False, sm_scale=None, return_log_normalizer=False, return_total_attention=False)\n```\n\nIn addition to the attention outputs, it can return extra outputs depending on `return_log_normalizer` and `return_total_attention`.\n\n1. log_normalizer: shape (batch_size, num_heads, seqlen_q). The log normalizer of the softmax inside the attention operation.\n2. total_attention: shape (batch_size, num_heads, seqlen_k). 
The sum of attention weights along the sequence axis of q.\n\n### piecewise_attention\n\nThe first extension to FlashAttention is [piecewise_attention](src/flag_attn/piecewise.py). This operator extends FlashAttention by using two `q`'s and two `k`'s to calculate the attention scores (S) before applying softmax to obtain the attention weights (P).\n\nThe rationale behind this design is the observation that a transformer with rotary position embedding struggles to predict sequences longer than the maximum sequence length it was trained on: pairs of `(q, k)` yield unexpectedly high attention scores when their distance exceeds the maximum sequence length in the training set.\n\nTo address this issue, BAAI proposes NLPE (Non-Linearized Position Embedding), which applies two different position embeddings to `q` and `k`, producing `q1, q2` and `k1, k2`. The attention score is then computed as the dot product of `q1, k1` or of `q2, k2`, depending on whether the distance between `q` and `k` exceeds a pre-defined threshold.\n\nThe interface is shown below.\n\n![piecewise_attention_interface](./assets/piecewise_attention_interface.png)\n\n```python\npiecewise_attention(q1, k1, q2, k2, v, dist_threshold, causal=False, sm_scale=None)\n```\n\nIn the forward computation it splices the two attention scores (S) together, and in the backward computation it splits the gradient of S accordingly.\n\n![piecewise attention](assets/piecewise_attention.png)\n\n#### Usage\n\n```python\n# piecewise_attention\nimport torch\nfrom flag_attn import piecewise_attention\n\nB, H, T, D = 2, 16, 8192, 128\ndist_threshold = T // 2\n\nq1 = torch.randn((B, H, T, D), dtype=torch.float16, device=\"cuda:0\").requires_grad_()\nq2 = torch.randn((B, H, T, D), dtype=torch.float16, device=\"cuda:0\").requires_grad_()\nk1 = torch.randn((B, H, T, D), dtype=torch.float16, device=\"cuda:0\").requires_grad_()\nk2 = torch.randn((B, H, T, D), dtype=torch.float16, 
device=\"cuda:0\").requires_grad_()\nv = torch.randn((B, H, T, D), dtype=torch.float16, device=\"cuda:0\").requires_grad_()\no = piecewise_attention(q1, k1, q2, k2, v, dist_threshold, causal=True)\nprint(o)\n\ngo = torch.randn((B, H, T, D), dtype=torch.float16, device=\"cuda:0\")\ngq1, gk1, gq2, gk2, gv = torch.autograd.grad(\n    o, (q1, k1, q2, k2, v), go\n)\nprint(gq1)\n```\n\n```python\n# flash_attention\nimport torch\nfrom flag_attn import flash_attention\n\nB, H, T, D = 2, 16, 8192, 128\n\nq = torch.randn((B, H, T, D), dtype=torch.float16, device=\"cuda:0\").requires_grad_()\nk = torch.randn((B, H, T, D), dtype=torch.float16, device=\"cuda:0\").requires_grad_()\nv = torch.randn((B, H, T, D), dtype=torch.float16, device=\"cuda:0\").requires_grad_()\no = flash_attention(q, k, v, causal=True)\nprint(o)\n\ngo = torch.randn((B, H, T, D), dtype=torch.float16, device=\"cuda:0\")\ngq, gk, gv = torch.autograd.grad(\n    o, (q, k, v), go\n)\nprint(gq)\n```\n\n#### Performance\n\nBenchmarks are performed under the following conditions.\n\n1. seqlen in `[512, 1k, 2k, 4k, 16k, 32k]`;\n2. batch size: `32k / seqlen`;\n3. headdim in `[64, 128]`;\n4. num_heads: `2048 / headdim`.\n\n##### flash_attention\n\nThe performance of flash_attention with causal masking is shown below.\n\n![headdim64](./assets/v0.2/flash_attention_d64.png)\n\n![headdim128](./assets/v0.2/flash_attention.png)\n\nThe forward operator runs as fast as, and in some cases faster than, FlashAttention (CUDA), but the backward operator is generally slower than FlashAttention. We first followed the paper and updated the gradient of Q with atomic additions in the backward operator, which ran extremely slowly. We then split the backward operator into two kernels, one computing the gradients of k and v and the other computing the gradient of q. This change avoids atomic additions but introduces more re-computation. 
Although this strategy yields a 4x to 5x speedup in the backward operator, it is still slower than FlashAttention (CUDA).\n\nThe same split-kernel trick is also applied to `piecewise_attention` for efficiency.\n\n##### piecewise_attention\n\nThe performance of piecewise_attention has improved compared to v0.1. With a head dim of 128 and causal masking, the forward and backward operators are 36% and 9% faster than in v0.1, respectively.\n\n![piecewise_attention](./assets/v0.2/piecewise_attention.png)\n\n#### Features\n\n- support for [Nvidia](https://www.nvidia.com/) Ampere GPUs (tested on RTX-3090 and A100);\n- support for [Iluvatar CoreX](https://www.iluvatar.com/) GPUs (tested on Iluvatar CoreX MR-V100);\n- support for the `float16` and `bfloat16` data types on Ampere Nvidia GPUs;\n- support for causal and non-causal modes;\n- support for forward \u0026 backward modes;\n- the sequence length of k/v can differ from that of q;\n- support for computing the total attention each `k` receives from all `q`'s;\n- support for returning the accumulated attention of each key;\n- support for [MQA](https://arxiv.org/abs/1911.02150) and [GQA](https://arxiv.org/pdf/2305.13245);\n- support for dropout of attention weights.\n\n#### Limitations\n\n- `headdim` should be in `[16, 32, 64, 128]`.\n\n## TODOs\n\n1. Test on other GPUs;\n2. Test on more versions of Triton;\n3. Improve the performance of attention operators (especially the backward operators);\n4. 
Support other extensions to FlashAttention.\n\n## More\n\nFor more about BAAI's open-source system for large models, please visit [BAAI/FlagOpen](https://flagopen.baai.ac.cn/).\n[\u003cimg src=\"./assets/logo/baai-flagopen.jpeg\"\u003e](https://flagopen.baai.ac.cn/)\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2FFlagOpen%2FFlagAttention","html_url":"https://awesome.ecosyste.ms/projects/github.com%2FFlagOpen%2FFlagAttention","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2FFlagOpen%2FFlagAttention/lists"}