{"id":28578437,"url":"https://github.com/xlite-dev/ffpa-attn","last_synced_at":"2025-06-11T01:10:18.762Z","repository":{"id":271186089,"uuid":"896023715","full_name":"xlite-dev/ffpa-attn","owner":"xlite-dev","description":"📚FFPA(Split-D): Extend FlashAttention with Split-D for large headdim, O(1) GPU SRAM complexity, 1.8x~3x↑🎉 faster than SDPA EA.","archived":false,"fork":false,"pushed_at":"2025-05-10T04:56:13.000Z","size":4418,"stargazers_count":183,"open_issues_count":3,"forks_count":8,"subscribers_count":3,"default_branch":"main","last_synced_at":"2025-06-01T15:48:14.679Z","etag":null,"topics":["attention","cuda","deepseek","deepseek-r1","deepseek-v3","flash-attention","flash-mla","fused-mla","mla","mlsys","sdpa","tensor-cores"],"latest_commit_sha":null,"homepage":"https://zhuanlan.zhihu.com/p/13975660308","language":"Cuda","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"gpl-3.0","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/xlite-dev.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null,"zenodo":null}},"created_at":"2024-11-29T11:47:23.000Z","updated_at":"2025-05-30T09:11:50.000Z","dependencies_parsed_at":"2025-02-04T02:30:56.548Z","dependency_job_id":"3920e874-ac9f-4dea-b7e4-61318a6c5c6c","html_url":"https://github.com/xlite-dev/ffpa-attn","commit_stats":null,"previous_names":["deftruth/faster-prefill-attention","deftruth/cuffpa-py","deftruth/ffpa-attn-mma","xlite-dev/ffpa-attn-mma","xlite-dev/ffpa-attn"],"tags_count":9,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/xlite-dev%2Fffpa-attn","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/xlite-dev%2Fffpa-attn/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/xlite-dev%2Fffpa-attn/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/xlite-dev%2Fffpa-attn/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/xlite-dev","download_url":"https://codeload.github.com/xlite-dev/ffpa-attn/tar.gz/refs/heads/main","sbom_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/xlite-dev%2Fffpa-attn/sbom","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":259178542,"owners_count":22817389,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["attention","cuda","deepseek","deepseek-r1","deepseek-v3","flash-attention","flash-mla","fused-mla","mla","mlsys","sdpa","tensor-cores"],"created_at":"2025-06-11T01:10:12.319Z","updated_at":"2025-06-11T01:10:18.740Z","avatar_url":"https://github.com/xlite-dev.png","language":"Cuda","readme":"\u003cdiv align=\"center\"\u003e\n  \u003cp align=\"center\"\u003e\n    \u003ch2\u003e🤖FFPA(Split-D): Yet another Faster Flash Prefill Attention with 
    <a href="https://zhuanlan.zhihu.com/p/13975660308">📚FFPA(Split-D) Blog</a> | <a href="#L1-bench-l20"> 📈L20 ~1.9x↑🎉 </a> | <a href="#L1-bench-a30"> 📈A30 ~1.8x↑🎉 </a> | <a href="#L1-bench-3080"> 📈3080 ~2.9x↑🎉 </a> | <a href="#L1-bench-4090"> 📈4090 ~2.1x↑🎉 </a> <p>
  </p>
  <img src=https://github.com/user-attachments/assets/4abfae2d-5a26-4f73-aaa2-d1e452a4215d width=250 >
  <div align='center'>
    <img src=https://img.shields.io/badge/Language-CUDA/Python-brightgreen.svg >
    <img src=https://img.shields.io/github/watchers/xlite-dev/ffpa-attn?color=9cc >
    <img src=https://img.shields.io/github/forks/xlite-dev/ffpa-attn.svg?style=social >
    <img src=https://img.shields.io/github/stars/xlite-dev/ffpa-attn.svg?style=social >
    <img src=https://img.shields.io/badge/Release-v0.0.2-brightgreen.svg >
    <img src=https://img.shields.io/badge/License-GPLv3.0-turquoise.svg >
 </div>
</div>

<div align="center">
  <p align="center"> <h2> 🤖FFPA: 1.8x~3x🎉 faster than SDPA EA, with or without MMA Acc F32</h2></p>
</div>

🤖[WIP] **FFPA**: Yet another **Faster Flash Prefill Attention** with **O(1) SRAM complexity** & **O(d/4) or O(1) register complexity** for large headdim (D > 256), almost **1.8x~3x** 🎉 faster than SDPA EA with or without MMA Acc F32 on many devices: [📈L20 ~1.9x↑🎉](#L1-bench-l20), [📈A30 ~1.8x↑🎉](#L1-bench-a30), [📈3080 ~2.9x↑🎉](#L1-bench-3080), [📈4090 ~2.1x↑🎉](#L1-bench-4090). **FFPA Attention Algo: Fine-grained tiling** for large headdim; **FA-2 Attention Algo: Coarse-grained tiling** for small headdim.

<!--
![image](https://github.com/user-attachments/assets/b881cef6-3c49-4a2a-b390-43b328de7b10)
![FFPA vs FA2](https://github.com/user-attachments/assets/c6cefc9a-5ef1-48ee-8c7d-68346c60bdcb)
-->
<img width="1496" alt="image" src="https://github.com/user-attachments/assets/6b5cc7c1-50f9-42cb-a123-4bf5b4ac8d6c" />

💡NOTE: This project is still in its early dev stages and currently provides some kernels and benchmarks for reference. More features will be added in the future. (Welcome to 🌟👆🏻star this repo to support me ~)

## ©️Citations🎉🎉

```BibTeX
@misc{ffpa-attn@2025,
  title={FFPA: Yet another Faster Flash Prefill Attention for large headdim.},
  url={https://github.com/xlite-dev/ffpa-attn.git},
  note={Open-source software available at https://github.com/xlite-dev/ffpa-attn.git},
  author={xlite-dev etc},
  year={2025}
}
```

## 📖 Contents

- [📖 Installation⚙️](#install)
- [📖 Python Testing👇](#python-test)
- [📖 FFPA L1~L3 Design💡](#ffpa-design)
- [📈 FFPA L1: L20 ~1.9x↑🎉](#L1-bench-l20)
- [📈 FFPA L1: A30 ~1.8x↑🎉](#L1-bench-a30)
- [📈 FFPA L1: 3080 ~2.9x↑🎉](#L1-bench-3080)
- [📈 FFPA L1: 4090 ~2.1x↑🎉](#L1-bench-4090)
- [📖 Fully Fused MLA w/ FFPA🎉](#fused-mla)

## 📖 FFPA L1~L3: FlashAttention + QKV Fine-grained Tiling at MMA level💡
<div id="ffpa-design"></div>

We have extended FlashAttention for large headdim (D > 256) by implementing **Fine-grained Tiling** at the **MMA level (GEMM style)** for the Q@K^T and P@V matmuls. This approach results in a constant SRAM usage of Br * 16 or Bc * 16 (Br = Bc) for Q, K, and V, leading to an overall SRAM complexity of O(2 * Br * 16) ≈ O(1) and a register complexity of O(d/4) or O(1). Consequently, this method allows us to extend headdim beyond 256 and achieve faster performance than SDPA, with or without MMA Accumulation F32 (**1.8x~3x** 🎉 faster than SDPA EA).

We have named this new attention tiling technique **FFPA: Faster Flash Prefill Attention**. We have designed three levels (`L1~L3`) of FFPA based on SRAM and register complexity considerations. None of these levels introduces any additional VRAM requirements, so the HBM memory complexity remains the same as in FlashAttention. 👇

- [x] 📚L1: level 1, O(2xBrx16)≈O(1) SRAM complexity, ≈O(d/4) register complexity.
- [ ] 📚L2: level 2, O(2xBrx16)≈O(1) SRAM complexity, ≈O(1) register complexity + Q@K^T recomputation.
- [ ] 📚L3: level 3, O(2xBrx16)≈O(1) SRAM complexity, ≈O(1) register complexity + scaling O via HBM offloading.

By leveraging this approach, we can achieve better performance than SDPA EA for very large headdim (D > 256, `FA-2 not supported`). An approximate SRAM and register complexity analysis for the FFPA L1~L3 levels is as follows (`d`=headdim, `C,Br,Bc`=constants, `Br=Bc`, let O(C)≈O(1)); a concrete byte-count sketch follows the table. 👇

|📚Complexity| 📚FFPA L1 |  📚FFPA L2 |  📚FFPA L3 | 📚FA-2 |
|:---:|:---:|:---:|:---:|:---:|
|SRAM | O(2xBrx16)≈O(1) | O(2xBrx16)≈O(1) | O(2xBrx16)≈O(1) | ≈O(3xBrxd), d↑ |
|Register | ≈O(d/4), d↑ | O((Bc/16)x4+2C)≈O(1)|O((Bc/16)x4+2C)≈O(1)| ≈O(d/2), d↑ |
|HBM| ≈FA2≈O(Nd), O | ≈FA2≈O(Nd), O| ≈FA2≈O(Nd), O | ≈O(Nd), O |
|Extra HBM| ≈FA2≈O(N), m,l | ≈FA2≈O(N), m,l | ≈FA2≈O(N), m,l | ≈O(N), m,l |
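
For intuition, here is a back-of-the-envelope sketch (not part of the library) that turns the SRAM column above into bytes, assuming fp16 tiles, Br = Bc = 64, and a single pipeline stage; constant factors such as multi-stage buffers and swizzle/padding are deliberately omitted.

```python
# Rough per-block shared-memory budget sketch (fp16 = 2 bytes/element).
# FA-2 keeps full Q[Br, d], K[Bc, d], V[Bc, d] tiles resident -> ~3 * Br * d.
# FFPA's fine-grained tiling keeps only 16-wide slices of the headdim
# resident at a time -> ~2 * Br * 16, independent of d.
def fa2_smem_bytes(br: int, d: int, dtype_bytes: int = 2) -> int:
    return 3 * br * d * dtype_bytes

def ffpa_smem_bytes(br: int, dtype_bytes: int = 2) -> int:
    return 2 * br * 16 * dtype_bytes

if __name__ == "__main__":
    br = 64
    for d in (256, 512, 1024):
        print(f"d={d:4d}  FA-2 ~{fa2_smem_bytes(br, d) / 1024:6.1f} KiB"
              f"  FFPA ~{ffpa_smem_bytes(br) / 1024:4.1f} KiB")
    # d= 256  FA-2 ~  96.0 KiB  FFPA ~ 4.0 KiB
    # d= 512  FA-2 ~ 192.0 KiB  FFPA ~ 4.0 KiB  (already beyond SMEM per SM on many GPUs)
    # d=1024  FA-2 ~ 384.0 KiB  FFPA ~ 4.0 KiB
```
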
**📚👇Core Features🎉🎉**: I have implemented **FFPA L1~L3** using pure MMA PTX instructions, which supports many features such as Split-Q, SMEM Swizzle/Padding, QKV Multi-Stages(1~4), Tile MMAs/Warps, Mixed MMA F32/F16 Acc (Q@K^T MMA Acc F32 + P@V MMA Acc F16), Fully Shared QKV SMEM, Prefetch QKV g2s, Persist Q s2r/g2s, **Fully QKV Fine-grained Tiling (GEMM style)**, Collective Store, etc.

|📚Feature |📚Feature |📚Feature |📚Feature|
|:---:|:---:|:---:|:---:|
|✔️Tensor Cores |✔️**MMA(m16n8k16)** |✔️Tile Block(Br, Bc) |✔️Tile MMA/Warp |
|✔️**Split Q**(FA-2)|✔️Pack LDST(128 bits)|✔️SMEM **Swizzle/Pad** |✔️Copy Async |
|✔️**Reg Double Buffers** |✔️QKV **Multi-Stages(1~4)** |✔️Collective Store(**Shfl**)|✔️**Prefetch QKV** g2s |
|✔️**QKV Fine-grained Tiling**|✔️**Shared QKV** SMEM|✔️Mixed MMA Acc|✔️**Persist Q** s2r/g2s|

- 📚 case: FFPA `L1` kernel template signature: [ffpa_attn_templates_L1.cuh](csrc/cuffpa/ffpa_attn_templates_L1.cuh)

```CUDA
template<
  const int kHeadDim,              // Headdim, 32~1024
  const int kMmaAtomM,             // MMA Atom M, 16
  const int kMmaAtomN,             // MMA Atom N, 8
  const int kMmaAtomK,             // MMA Atom K, 16
  const int kMmaTileSeqLenQ,       // 4, more MMA(warp), M=16*4=64, Q@K^T=[Br(M), d(K)]@[d(K),  Bc(N)]
  const int kMmaTileSeqLenK,       // 1, more MMA(warp), N=8*1 =8,  Q@K^T=[Br(M), d(K)]@[d(K),  Bc(N)]
  const int kMmaTileSeqLenP,       // 4, more MMA(warp), M=16*4=64, P@V  =[Br(M),Bc(K)]@[Bc(K), d(N) ]
  const int kMmaTileHeadDimV,      // 1, more MMA(warp), N=8*1 =8,  P@V  =[Br(M),Bc(K)]@[Bc(K), d(N) ]
  const int kWarpTileSeqLenQ,      // 1, more values, M, Br=64*1=64, matmul M
  const int kWarpTileSeqLenK,      // 8, more values, N, Bc=8*8 =64, matmul N
  const int kWarpTileSeqLenP,      // 1, more values, M, Br=64*1=64, matmul M
  const int kWarpTileHeadDimV,     // 8, more values, N, d=8*(1|2|3|4|...)=8|...|32|64|96|128|...
  const int kMmaAccFloat32QK,      // 0/1, Q@K^T, 0: MMA Acc with fp16, 1: MMA Acc with fp32.
  const int kMmaAccFloat32PV,      // 0/1, P@V,   0: MMA Acc with fp16, 1: MMA Acc with fp32.
  const int kOStorageAccFloat32,   // 0/1, MMA Acc is always f32/f16, but O storage can be fp32 or half.
  const int kPrefetchQK,           // Prefetch QK at the appropriate time point.
  const int kPrefetchPV,           // Prefetch V at the appropriate time point.
  const int kShareSmemQKV,         // QKV share the same shared memory, reuse QK smem for V.
  const int kPersistQs2r,          // Persist load Q s2r for headdim  < 512, more registers, but still keep O(1) SRAM.
  const int kPersistQg2s,          // Persist load Q g2s for headdim <= 320, more SRAM, but still keep register usage.
  const int kRegPipeKV,            // Register ping-pong double buffers for overlapping ldmatrix s2r & MMA computation.
  const int kStageQK,              // <= 4, may apply different multi-stage policies for QK and V (<=4)
  const int kStagePV,              // <= 4, may apply different multi-stage policies for QK and V (<=4)
  const int kPadQ,                 // Pad Q/K/V 0,8; 0 -> smem swizzle, > 0 -> padding
  const int kPadK,                 // Pad Q/K/V 0,8; 0 -> smem swizzle, > 0 -> padding
  const int kPadV                  // Pad Q/K/V 0,8; 0 -> smem swizzle, > 0 -> padding
> __global__ void // Q, K, V, O -> [B, H, N, D]
// FFPA Attention Algo: Fine-grained tiling at MMA level for large headdim (d>=256),
// which can achieve 1.8x~3x🎉 faster than SDPA EA with or without MMA Acc F32.
ffpa_mma_stages_split_q_L1_large_d_template(half* Q, half* K, half* V, half* O, ...);
// FA-2 Attention Algo: Coarse-grained tiling at Attention level for small headdim (d<256),
// which can achieve 95%-105%🎉 of SDPA FA-2 BE performance with MMA Acc F32 for N<=4096,
// and run almost 1.2x~1.4x🎉 faster than SDPA FA-2 via Mixed MMA Acc (Q@K^T F32 +
// P@V F16) across the whole range of N.
ffpa_mma_stages_split_q_L1_small_d_template(half* Q, half* K, half* V, half* O, ...);
```
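
As a reading aid for the template parameters above, the following standalone sketch (not part of the codebase) reproduces the Br/Bc arithmetic stated in the comments with their default values; the warp count is inferred from the `(warp)` annotations and should be treated as an assumption.

```python
# Tile-shape arithmetic implied by the template comments above
# (defaults: MMA atom m16n8k16, 4x1 MMA tiles, 1x8 warp tiles).
kMmaAtomM, kMmaAtomN, kMmaAtomK = 16, 8, 16
kMmaTileSeqLenQ, kMmaTileSeqLenK = 4, 1
kWarpTileSeqLenQ, kWarpTileSeqLenK = 1, 8
kWarpTileHeadDimV = 8

# Br: Q rows handled per thread block (matmul M), per the "M=16*4=64" comment.
Br = kMmaAtomM * kMmaTileSeqLenQ * kWarpTileSeqLenQ      # 16 * 4 * 1 = 64
# Bc: K columns handled per thread block (matmul N), per the "Bc=8*8=64" comment.
Bc = kMmaAtomN * kMmaTileSeqLenK * kWarpTileSeqLenK      # 8 * 1 * 8 = 64
# Headdim slice of V covered per warp tile, per the "d=8*(...)" comment.
d_v_tile = kMmaAtomN * kWarpTileHeadDimV                 # 8 * 8 = 64
# Warps per block (assumed: one warp per MMA-tile M x N pair).
warps = kMmaTileSeqLenQ * kMmaTileSeqLenK                # 4 * 1 = 4

print(f"Br={Br}, Bc={Bc}, d_v_tile={d_v_tile}, warps/block={warps} (assumed)")
```
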
## 📖 Prerequisites
<div id="prerequisites"></div>

- Python >= 3.10
- PyTorch >= 2.4.0, CUDA >= 12.4
- flash-attention >= 2.6.3 (for test)
- Recommended: PyTorch 2.5.1, CUDA 12.5
- Docker: nvcr.io/nvidia/pytorch:24.10-py3

## 📖 Installation

<div id="install"></div>

The FFPA implemented in this repo can be installed as a Python library, namely the `ffpa-attn` library (optional).
```bash
git clone https://github.com/xlite-dev/ffpa-attn.git
# clone, then run 'bash .dev/install.sh' directly, or run these commands:
python3 setup.py bdist_wheel && cd dist && python3 -m pip install *.whl # pip uninstall ffpa-attn -y
```
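
The benchmarks below use PyTorch SDPA pinned to the memory-efficient backend ("SDPA EA") as the baseline, since FA-2 does not support these headdims. For reference, here is a minimal sketch of how such a baseline can be timed with standard PyTorch APIs (PyTorch >= 2.4 assumed); it is only an illustration of the baseline setup, not the repo's benchmark script.

```python
# Minimal SDPA EA (efficient-attention backend) baseline sketch for large headdim.
import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

B, H, N, D = 1, 48, 8192, 320  # FA-2 does not support D=320; the EA backend does.
q = torch.randn(B, H, N, D, device="cuda", dtype=torch.half)
k = torch.randn(B, H, N, D, device="cuda", dtype=torch.half)
v = torch.randn(B, H, N, D, device="cuda", dtype=torch.half)

with sdpa_kernel(SDPBackend.EFFICIENT_ATTENTION):
    F.scaled_dot_product_attention(q, k, v)  # warmup
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(5):
        o = F.scaled_dot_product_attention(q, k, v)
    end.record()
    torch.cuda.synchronize()

print(f"SDPA EA: {start.elapsed_time(end) / 5:.3f} ms/iter")
```
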
## 📖 FFPA L1 (Level 1): Benchmark 🎉🎉

<div id="L1-bench-l20"></div>

L1: level 1, O(2xBrx16)≈O(1) SRAM complexity, O(d/4) register complexity, and the same GPU HBM memory complexity as FlashAttention. B=1, H=48, N=8192, **D=320-1024 (FA2 not supported 👀)**. (Notes: `*`=MMA Acc F32, `^`=MMA Acc F16, the Softmax Acc dtype is always F32, T=TFLOPS, 👇Benchmark)

- 📚 NVIDIA L20 (`*`=MMA Acc F32, `^`=MMA Acc F16, `T`=TFLOPS, **~1.8x↑🎉**)

|Algorithm|320|384|448|512|576|640|704|768|832|896|960|1024|
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
|SDPA EA|56T|63T|58T|58T|55T|56T|54T|55T|54T|55T|54T|56T|
|FFPA L1*|102T|102T|103T|104T|103T|95T|95T|95T|95T|96T|95T|94T|
|Speedup|1.82x|1.62x|1.78x|1.79x|1.87x|1.7x|1.76x|1.73x|1.76x|1.75x|1.76x|1.68x|
|FFPA L1^|104T|103T|103T|102T|104T|103T|102T|94T|94T|94T|100T|100T|
|Speedup|1.86x|1.63x|1.78x|1.76x|1.89x|1.84x|1.89x|1.71x|1.74x|1.71x|1.85x|1.79x|

- 📚 NVIDIA L20 (`*`=MMA Acc: QK F32 + PV F16, `^`=MMA Acc F16, `T`=TFLOPS, **~1.9x↑🎉**)

|Algorithm|320|384|448|512|576|640|704|768|832|896|960|1024|
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
|SDPA EA|56T|64T|58T|58T|55T|56T|54T|55T|54T|55T|54T|56T|
|FFPA L1*|105T|102T|104T|103T|105T|95T|95T|94T|94T|94T|102T|101T|
|Speedup|1.88x|1.59x|1.79x|1.78x|1.91x|1.7x|1.76x|1.71x|1.74x|1.71x|1.89x|1.8x|
|FFPA L1^|104T|103T|103T|102T|103T|103T|102T|94T|94T|94T|100T|100T|
|Speedup|1.86x|1.61x|1.78x|1.76x|1.87x|1.84x|1.89x|1.71x|1.74x|1.71x|1.85x|1.79x|

<div align='left'>
  <img src='https://github.com/user-attachments/assets/a4927108-3f97-4209-9b80-bb31ad271e04' width="411px">
  <img src='https://github.com/user-attachments/assets/eeb9943f-919d-45d8-a8a6-e0f8874f4bcd' width="411px">
</div>

<div id="L1-bench-a30"></div>

- 📚 NVIDIA A30 (`*`=MMA Acc F32, `^`=MMA Acc F16, `T`=TFLOPS, **~1.8x↑🎉**)

|Algorithm|320|384|448|512|576|640|704|768|832|896|960|1024|
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
|SDPA EA|25T|25T|24T|24T|24T|24T|23T|22T|22T|22T|22T|18T|
|FFPA L1*|45T|44T|44T|43T|43T|38T|37T|37T|37T|36T|33T|32T|
|Speedup|1.8x|1.76x|1.83x|1.79x|1.79x|1.58x|1.61x|1.68x|1.68x|1.64x|1.5x|1.78x|
|FFPA L1^|48T|46T|45T|43T|44T|44T|44T|38T|37T|36T|40T|34T|
|Speedup|1.92x|1.84x|1.88x|1.79x|1.83x|1.83x|1.91x|1.73x|1.68x|1.64x|1.82x|1.89x|

- 📚 NVIDIA A30 (`*`=MMA Acc: QK F32 + PV F16, `^`=MMA Acc F16, `T`=TFLOPS, **~1.9x↑🎉**)

|Algorithm|320|384|448|512|576|640|704|768|832|896|960|1024|
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
|SDPA EA|25T|25T|24T|24T|24T|24T|23T|22T|22T|22T|22T|18T|
|FFPA L1*|48T|46T|46T|43T|44T|38T|38T|38T|37T|36T|40T|34T|
|Speedup|1.92x|1.84x|1.92x|1.79x|1.83x|1.58x|1.65x|1.73x|1.68x|1.64x|1.82x|1.89x|
|FFPA L1^|48T|46T|45T|43T|44T|44T|44T|38T|37T|36T|39T|34T|
|Speedup|1.92x|1.84x|1.88x|1.79x|1.83x|1.83x|1.91x|1.73x|1.68x|1.64x|1.77x|1.89x|

<div align='left'>
  <img src='https://github.com/user-attachments/assets/7e323005-4445-41af-8e94-6efb62ed2b77' width="411px">
  <img src='https://github.com/user-attachments/assets/e314649e-82b5-414d-85c9-8b6fbf260138' width="411px">
</div>
<div id="L1-bench-3080"></div>

- 📚 NVIDIA RTX 3080 Laptop (`*`=MMA Acc F32, `^`=MMA Acc F16, `T`=TFLOPS, **~2.5x↑🎉**)

|Algorithm|320|384|448|512|576|640|704|768|832|896|960|1024|
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
|SDPA EA|13T|16T|11T|16T|15T|15T|15T|15T|14T|14T|14T|14T|
|FFPA L1*|33T|31T|30T|30T|30T|27T|27T|26T|26T|26T|26T|25T|
|Speedup|2.54x|1.94x|2.73x|1.88x|2.0x|1.8x|1.8x|1.73x|1.86x|1.86x|1.86x|1.79x|
|FFPA L1^|43T|41T|39T|39T|39T|39T|39T|36T|34T|33T|31T|33T|
|Speedup|3.31x|2.56x|3.55x|2.44x|2.6x|2.6x|2.6x|2.4x|2.43x|2.36x|2.21x|2.36x|

- 📚 NVIDIA RTX 3080 Laptop (`*`=MMA Acc: QK F32 + PV F16, `^`=MMA Acc F16, `T`=TFLOPS, **~2.9x↑🎉**)

|Algorithm|320|384|448|512|576|640|704|768|832|896|960|1024|
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
|SDPA EA|13T|15T|12T|15T|14T|15T|14T|14T|14T|14T|14T|14T|
|FFPA L1*|38T|36T|34T|35T|34T|31T|32T|31T|30T|28T|27T|27T|
|Speedup|2.92x|2.4x|2.83x|2.33x|2.43x|2.07x|2.29x|2.21x|2.14x|2.0x|1.93x|1.93x|
|FFPA L1^|44T|41T|39T|39T|38T|39T|39T|36T|34T|32T|31T|33T|
|Speedup|3.38x|2.73x|3.25x|2.6x|2.71x|2.6x|2.79x|2.57x|2.43x|2.29x|2.21x|2.36x|

<div align='left'>
  <img src='https://github.com/user-attachments/assets/d157cd69-4444-4735-a691-edaaff408137' width="411px">
  <img src='https://github.com/user-attachments/assets/3ce47627-e79d-40ee-b753-bdd235603b7d' width="411px">
</div>

<div id="L1-bench-4090"></div>

- 📚 NVIDIA RTX 4090 (`*`=MMA Acc F32, `^`=MMA Acc F16, `T`=TFLOPS, **~1.8x↑🎉**)

|Algorithm|320|384|448|512|576|640|704|768|832|896|960|1024|
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
|SDPA EA|81T|94T|85T|85T|79T|81T|79T|80T|79T|80T|78T|78T|
|FFPA L1*|149T|150T|150T|150T|150T|140T|140T|140T|139T|139T|137T|134T|
|Speedup|1.84x|1.6x|1.76x|1.76x|1.9x|1.73x|1.77x|1.75x|1.76x|1.74x|1.76x|1.72x|
|FFPA L1^|194T|194T|189T|191T|197T|188T|184T|180T|177T|172T|171T|171T|
|Speedup|2.4x|2.06x|2.22x|2.25x|2.49x|2.32x|2.33x|2.25x|2.24x|2.15x|2.19x|2.19x|

- 📚 NVIDIA RTX 4090 (`*`=MMA Acc: QK F32 + PV F16, `^`=MMA Acc F16, `T`=TFLOPS, **~2.1x↑🎉**)

|Algorithm|320|384|448|512|576|640|704|768|832|896|960|1024|
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
|SDPA EA|82T|92T|85T|84T|78T|81T|79T|80T|78T|79T|77T|78T|
|FFPA L1*|176T|170T|171T|171T|171T|161T|160T|161T|160T|158T|165T|164T|
|Speedup|2.15x|1.85x|2.01x|2.04x|2.19x|1.99x|2.03x|2.01x|2.05x|2.0x|2.14x|2.1x|
|FFPA L1^|200T|191T|189T|191T|188T|188T|186T|179T|175T|173T|172T|170T|
|Speedup|2.44x|2.08x|2.22x|2.27x|2.41x|2.32x|2.35x|2.24x|2.24x|2.19x|2.23x|2.18x|

<div align='left'>
  <img src='https://github.com/user-attachments/assets/447e2937-f7c8-47c8-8550-8c0c71b910e6' width="411px">
  <img src='https://github.com/user-attachments/assets/65a8d564-8fa7-4d66-86b9-e238feb86143' width="411px">
</div>

## 📖 Python Testing
<div id="python-test"></div>

👇You can test many custom FFPA kernels via Python and compare their performance. The `--gen-bench` and `--plot` options help you generate a benchmark table in Markdown style and speedup bar plots on your device. Contributions of your benchmark tables and plots are welcome via a PR 🎉🎉.

- 📚 case: B=1, H=48, N=8192, D=320 (`FA2 not supported`)
```python
# You can test on many devices, such as Volta, Ampere, Ada, Hopper, ...
cd tests && python3 test_ffpa_attn.py --B 1 --H 48 --N 8192 --show-all --D 320
---------------------------------------B=1, H=48, N=8192, D=320, Warmup: 1, Iters: 5--------------------
                   (sdpa): ['-0.02380371'], time:73.66518ms, TFLOPS:56.19 (+0.00 %)(~1.00x)
 (ffpa+acc+f32+L1+stage1): ['-0.02378845'], time:52.87361ms, TFLOPS:78.28 (+39.32%)(~1.39x)
 (ffpa+acc+f32+L1+stage2): ['-0.02378845'], time:40.84062ms, TFLOPS:101.35(+29.46%)(~1.80x)
 (ffpa+acc+f32+L1+stage3): ['-0.02378845'], time:40.49534ms, TFLOPS:102.21(+0.85 %)(~1.82x)
 (ffpa+acc+f32+L1+stage4): ['-0.02378845'], time:40.88177ms, TFLOPS:101.25(+0.00 %)(~1.80x)
 (ffpa+acc+f16+L1+stage1): ['-0.02378845'], time:53.43298ms, TFLOPS:77.46 (+0.00 %)(~1.38x)
 (ffpa+acc+f16+L1+stage2): ['-0.02378845'], time:39.76068ms, TFLOPS:104.10(+1.85 %)(~1.85x)
 (ffpa+acc+f16+L1+stage3): ['-0.02378845'], time:39.54901ms, TFLOPS:104.66(+0.54 %)(~1.86x)
 (ffpa+acc+f16+L1+stage4): ['-0.02378845'], time:41.06554ms, TFLOPS:100.79(+0.00 %)(~1.79x)
--------------------------------------------------------------------------------------------------------
```
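
The TFLOPS column above is consistent with the conventional non-causal attention FLOP count of 4·B·H·N²·D (Q@K^T plus P@V, 2 FLOPs per multiply-add). The snippet below is only an illustration of that accounting (the test script's exact bookkeeping may differ slightly); it re-derives the sdpa and stage3 figures from the printed times.

```python
# Re-deriving the TFLOPS column from the printed times (D=320 case above).
B, H, N, D = 1, 48, 8192, 320
flops = 4 * B * H * N * N * D  # Q@K^T and P@V: two N x N x D matmuls, 2 FLOPs per MAC

for name, ms, reported in [("sdpa", 73.66518, 56.19),
                           ("ffpa+acc+f32+L1+stage3", 40.49534, 102.21)]:
    tflops = flops / (ms * 1e-3) / 1e12
    print(f"{name}: computed ~{tflops:.2f} TFLOPS (script reports {reported})")
# sdpa: computed ~55.97 TFLOPS (script reports 56.19)
# ffpa+acc+f32+L1+stage3: computed ~101.82 TFLOPS (script reports 102.21)
```
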
- 📚 case: Generate benchmark table and speedup bar plots on your device.
```bash
cd tests && pip install matplotlib && python3 test_ffpa_attn.py --gen-bench --show-all --plot
```
- 📚 case: Compare small headdim (d<256, e.g. 64), FFPA-L1 vs SDPA FA-2 BE.
```python
# Enable the ffpa-attn small-d kernel, which uses the coarse-grained tiling method.
export ENABLE_FFPA_PERSIST_Q_G2S=1 && export ENABLE_FFPA_PERSIST_KV_G2S=1
cd tests && python3 test_ffpa_attn.py --B 1 --H 32 --N 1024 --check --show-all --D 64 # NVIDIA L20
---------------------------------------B=1, H=32, N=1024, D=64, Warmup: 1, Iters: 5--------------------
                   (sdpa): ['0.00802612'], time:0.148057ms, TFLOPS:59.14 (+0.00 %)(~1.00x)
 (ffpa+acc+f32+L1+stage1): ['0.00803375'], time:0.103807ms, TFLOPS:84.34 (+42.63%)(~1.43x)
 (ffpa+acc+f32+L1+stage2): ['0.00803375'], time:0.102233ms, TFLOPS:85.64 (+1.54 %)(~1.45x)
 (ffpa+acc+f32+L1+stage3): ['0.00803375'], time:0.102519ms, TFLOPS:85.40 (+0.00 %)(~1.44x)
 (ffpa+acc+f32+L1+stage4): ['0.00803375'], time:0.102043ms, TFLOPS:85.80 (+0.19 %)(~1.45x)
 (ffpa+acc+f16+L1+stage1): ['0.00795746'], time:0.104713ms, TFLOPS:83.61 (+0.00 %)(~1.41x)
 (ffpa+acc+f16+L1+stage2): ['0.00795746'], time:0.102949ms, TFLOPS:85.05 (+0.00 %)(~1.44x)
 (ffpa+acc+f16+L1+stage3): ['0.00795746'], time:0.108957ms, TFLOPS:80.36 (+0.00 %)(~1.36x)
 (ffpa+acc+f16+L1+stage4): ['0.00795746'], time:0.103282ms, TFLOPS:84.77 (+0.00 %)(~1.43x)
--------------------------------------------------------------------------------------------------------
cd tests && python3 test_ffpa_attn.py --B 1 --H 32 --N 4096 --check --show-all --D 64 # NVIDIA L20
-------------------------B=1, H=32, N=4096, D=64, Warmup: 1, Iters: 5-----------------------------------
                   (sdpa): ['0.01959229'], time:1.397752ms, TFLOPS:100.24(+0.00 %)(~1.00x)
 (ffpa+acc+f32+L1+stage1): ['0.01959229'], time:1.368856ms, TFLOPS:102.36(+2.11 %)(~1.02x)
 (ffpa+acc+f32+L1+stage2): ['0.01959229'], time:1.367807ms, TFLOPS:102.44(+0.08 %)(~1.02x)
 (ffpa+acc+f32+L1+stage3): ['0.01959229'], time:1.367855ms, TFLOPS:102.43(+0.00 %)(~1.02x)
 (ffpa+acc+f32+L1+stage4): ['0.01959229'], time:1.368045ms, TFLOPS:102.42(+0.00 %)(~1.02x)
 (ffpa+acc+f16+L1+stage1): ['0.01957703'], time:1.389312ms, TFLOPS:100.85(+0.00 %)(~1.01x)
 (ffpa+acc+f16+L1+stage2): ['0.01957703'], time:1.388311ms, TFLOPS:100.92(+0.00 %)(~1.01x)
 (ffpa+acc+f16+L1+stage3): ['0.01957703'], time:1.386976ms, TFLOPS:101.02(+0.00 %)(~1.01x)
 (ffpa+acc+f16+L1+stage4): ['0.01957703'], time:1.387834ms, TFLOPS:100.96(+0.00 %)(~1.01x)
--------------------------------------------------------------------------------------------------------
```

💡NOTE: Please check all configurable environment variables in [env.py](./env.py).

## 📖 Fully Fused MLA with FFPA 🎉

<div id="fused-mla"></div>

Extending FA's support to large headdim is meaningful in the context of **DeepSeek MLA**. For example, once FA supports headdim values greater than 512, MLA can be fully fused into a single CUDA kernel after W_UK/W_UV are absorbed into W_Q/W_O (resulting in C_kv/C_q with `dc/dc' >= 512`). A shape sketch follows the TODO list below. TODO list👇:

- [ ] 📚Fully fuse MLA into a single CUDA kernel using the **FFPA** Algo and Tensor Cores.
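
As an illustration of why headdim > 512 matters here, the sketch below walks through the effective attention shapes after weight absorption. The dimension values follow DeepSeek-V2/V3's published MLA configuration (kv_lora_rank dc = 512, rope dim 64) and are used purely as an example; the fused kernel itself is still TODO.

```python
# Shape sketch for weight-absorbed MLA attention (illustrative DeepSeek-like values).
# After absorbing W_UK into W_Q and W_UV into W_O, attention runs directly against
# the compressed KV cache C_kv, so the effective headdims become:
dc, d_rope = 512, 64

d_qk = dc + d_rope   # Q/K effective headdim: 512 + 64 = 576 (> 512)
d_v = dc             # V effective headdim: 512

print(f"Q@K^T headdim = {d_qk}, P@V headdim = {d_v}")
# Both exceed FA-2's supported headdim (<= 256), which is where FFPA's
# O(1)-SRAM fine-grained tiling comes in.
```
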
## ©️License

<div id="License"></div>

GNU General Public License v3.0

## 🎉Contribute

<div id="Contribute"></div>

How to contribute? Welcome to star⭐️ this repo to support me👆🏻 ~

<div align='center'>
<a href="https://star-history.com/#xlite-dev/ffpa-attn&Date">
 <picture>
   <source media="(prefers-color-scheme: dark)" srcset="https://api.star-history.com/svg?repos=xlite-dev/ffpa-attn&type=Date&theme=dark" />
   <source media="(prefers-color-scheme: light)" srcset="https://api.star-history.com/svg?repos=xlite-dev/ffpa-attn&type=Date" />
   <img width=450 height=300 alt="Star History Chart" src="https://api.star-history.com/svg?repos=xlite-dev/ffpa-attn&type=Date" />
 </picture>
</a>
</div>

## 📖 References
<div id="ref"></div>

- [flash-attention](https://github.com/Dao-AILab/flash-attention)
- [LeetCUDA](https://github.com/xlite-dev/LeetCUDA)
- [flashinfer](https://github.com/flashinfer-ai/flashinfer)