Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/DefTruth/cuffpa-py
[WIP] FFPA: Yet another Faster Flash Prefill Attention with O(1) GPU SRAM complexity for headdim > 256, ~1.5x faster than SDPA EA.
attention cuda flash-attention mlsys sdpa tensor-cores
Last synced: 19 days ago
JSON representation
- Host: GitHub
- URL: https://github.com/DefTruth/cuffpa-py
- Owner: DefTruth
- License: gpl-3.0
- Created: 2024-11-29T11:47:23.000Z (about 2 months ago)
- Default Branch: main
- Last Pushed: 2025-01-08T03:19:49.000Z (19 days ago)
- Last Synced: 2025-01-08T03:25:58.241Z (19 days ago)
- Topics: attention, cuda, flash-attention, mlsys, sdpa, tensor-cores
- Language: Cuda
- Homepage:
- Size: 210 KB
- Stars: 28
- Watchers: 1
- Forks: 1
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
Awesome Lists containing this project
- Awesome-LLM-Inference - **FFPA** ([cuffpa-py](https://github.com/DefTruth/cuffpa-py)) ![](https://img.shields.io/github/stars/DefTruth/cuffpa-py) (Contents / IO/FLOPs-Aware/Sparse Attention)
README
[WIP] **FFPA**: Yet another **Faster Flash Prefill Attention** with **O(1) SRAM complexity** & **O(d/4) or O(1) register complexity** for large headdim (D > 256), roughly **1.5x or more** faster than SDPA EA with or without MMA Accumulation F32 on many devices, such as NVIDIA L20, 4090, and 3080 Laptop (experimental). The FFPA kernels are modified from my repo [CUDA-Learn-Notes](https://github.com/DefTruth/CUDA-Learn-Notes/tree/main/kernels/flash-attn) ![](https://img.shields.io/github/stars/DefTruth/CUDA-Learn-Notes.svg?style=social).
NOTE: This project is still in its early development stage and currently provides a few experimental kernels and benchmarks for reference. More features will be added in the future. You are welcome to star this repo to support me ~
## ©️Citations
```BibTeX
@misc{cuffpa-py2025,
title={FFPA: Yet another Faster Flash Prefill Attention for large headdim.},
url={https://github.com/DefTruth/cuffpa-py},
note={Open-source software available at https://github.com/DefTruth/cuffpa-py},
author={DefTruth etc},
year={2025}
}
```

## Contents
- [Prerequisites](#prerequisites)
- [Installation](#install)
- [FFPA L1~L3 Design](#ffpa-design)
- [FFPA L1 Benchmark](#L1-bench)
- [FFPA L2 Benchmark](#L1-bench)
- [FFPA L3 Benchmark](#L1-bench)
- [Python Testing](#python-test)
- [References](#ref)

## FFPA L1~L3: FlashAttention + QKV Fine-grained Tiling at MMA level
We have extended FlashAttention for large headdim (D > 256) by implementing **Fine-grained Tiling** at the **MMA level (GEMM style)** for the Q@K^T and P@V matmuls. This approach results in a constant SRAM usage of Br * 16 or Bc * 16 for Q, K, and V, leading to an overall SRAM complexity of O(Br * 16) ≈ O(1) and a register complexity of O(d/4) or O(1). Consequently, this method allows us to extend **headdim > 256** and achieve faster performance than SDPA with or without MMA Accumulation F32 (roughly **1.5x or more** faster than SDPA EA).
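To make the constant-SRAM claim concrete, here is a rough back-of-the-envelope sketch in Python (the tile size Br = Bc = 64 and the fp16 element size are illustrative assumptions, not the kernels' actual configuration):

```python
# Rough per-block shared-memory estimate; fp16 = 2 bytes per element.
# Br = 64 is an illustrative tile size, not the kernels' actual config.
BYTES = 2
Br = 64

def fa2_style_sram_kb(d):
    # Classic FlashAttention-style tiling keeps full Q/K/V tiles on-chip:
    # roughly O(3 x Br x d) elements, so SRAM grows linearly with headdim d.
    return 3 * Br * d * BYTES / 1024

def ffpa_style_sram_kb(d):
    # MMA-level fine-grained tiling keeps only Br x 16 (or Bc x 16) slices
    # on-chip at a time: roughly O(2 x Br x 16) elements, independent of d.
    return 2 * Br * 16 * BYTES / 1024

for d in (256, 512, 1024):
    print(f"d={d:4d}: FA-2 style ~{fa2_style_sram_kb(d):5.0f} KB/block, "
          f"fine-grained tiling ~{ffpa_style_sram_kb(d):.0f} KB/block")
```

Under these assumptions, at d = 1024 the FA-2-style footprint (~384 KB per block) would far exceed the shared memory of a single SM, while the fine-grained-tiling footprint stays at a few KB, which is why headdim can be extended well beyond 256.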
We have named this new attention tiling technique **FFPA: Faster Flash Prefill Attention**. We have designed three levels (`L1~L3`) of FFPA based on SRAM and register complexity considerations. None of the levels introduces any additional VRAM requirements, so the HBM memory complexity remains the same as FlashAttention.
- [x] L1: level 1, O(Brx16)≈O(1) SRAM complexity, ≈O(d/4) register complexity.
- [ ] L2: level 2, O(Brx16)≈O(1) SRAM complexity, ≈O(1) register complexity + Q@K^T recomputation.
- [ ] L3: level 3, O(Brx16)≈O(1) SRAM complexity, ≈O(1) register complexity + scaling O via HBM offloading.

By leveraging this approach, we can achieve better performance for large headdim (D > 256) through a balanced utilization of FlashAttention (which is not designed to support D > 256) and SDPA EA. Approximate SRAM and register complexity analysis for L1~L3 is as follows (`d`=headdim, `C,Br,Bc`=Constant, `Br=Bc`):
|Complexity| FFPA L1 | FFPA L2 | FFPA L3 | FA-2 |
|:---:|:---:|:---:|:---:|:---:|
|SRAM | O(2xBrx16)≈O(1) | O(2xBrx16)≈O(1) | O(2xBrx16)≈O(1) | ≈O(3xBrxd), grows with d |
|Register | ≈O(d/4), grows with d | O((Bc/16)x4+2C)≈O(1) | O((Bc/16)x4+2C)≈O(1) | ≈O(d/2), grows with d |
|HBM| ≈FA2 | ≈FA2 | ≈FA2 | =FA2 |

## Prerequisites
- Python >= 3.10
- PyTorch >= 2.4.0, CUDA >= 12.4
- Recommended: PyTorch 2.5.1, CUDA 12.5

## Installation
The FFPA kernels implemented in this repo can be installed as a Python library, namely, the `cuffpa-py` library (optional).
```bash
git clone https://github.com/DefTruth/cuffpa-py.git
# after cloning, either run `bash .dev/install.sh` directly or run the following commands:
python3 setup.py bdist_wheel && rm -rf *.egg-info # build 'cuffpa-py' from sources
cd dist && python3 -m pip install cuffpa_py-*-linux_x86_64.whl # pip uninstall cuffpa-py -y
```

## FFPA L1 (Level 1): Benchmark
L1: level 1, O(2xBrx16)≈O(1) SRAM complexity, O(d/4) register complexity, and the same GPU HBM memory complexity as FlashAttention. Benchmark settings: B=1, H=48, N=8192, **D=320-1024 (FA2 not supported)**. (Notes: `*`=MMA Acc F32, `^`=MMA Acc F16, the Softmax Acc dtype is always F32, `T`=TFLOPS.)
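For reading the tables below: each Speedup row is simply the ratio of FFPA TFLOPS to SDPA EA TFLOPS at each headdim (the published values appear to be computed from unrounded timings, so they can differ slightly from ratios of the rounded TFLOPS shown). A minimal sketch:

```python
# Derive a Speedup row from the SDPA EA and FFPA L1* TFLOPS rows.
# Example values are the first four columns (D=320..512) of the
# RTX 3080 Laptop table below, using the rounded TFLOPS shown there.
sdpa_tflops = [13, 16, 12, 16]
ffpa_tflops = [32, 30, 30, 28]

speedups = [f"{f / s:.2f}x" for f, s in zip(ffpa_tflops, sdpa_tflops)]
print(speedups)  # ['2.46x', '1.88x', '2.50x', '1.75x'] vs. 2.48x/1.88x/2.55x/1.75x reported
```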
- NVIDIA RTX 3080 Laptop (`*`=MMA Acc F32, `^`=MMA Acc F16, `T`=TFLOPS)
|Algorithm|320|384|448|512|576|640|704|768|832|896|960|1024|
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
|SDPA EA|13T|16T|12T|16T|15T|15T|15T|15T|15T|15T|15T|15T|
|FFPA L1*|32T|30T|30T|28T|28T|27T|26T|25T|25T|25T|25T|24T|
|Speedup|2.48x|1.88x|2.55x|1.75x|1.90x|1.77x|1.73x|1.67x|1.66x|1.66x|1.66x|1.54x|
|FFPA L1^|40T|38T|39T|36T|35T|34T|33T|32T|31T|31T|28T|27T|
|Speedup|3.07x|2.42x|3.33x|2.24x|2.35x|2.19x|2.19x|2.13x|2.03x|2.03x|1.90x|1.74x|

- NVIDIA RTX 4090 (`*`=MMA Acc F32, `^`=MMA Acc F16, `T`=TFLOPS)
|Algorithm|320|384|448|512|576|640|704|768|832|896|960|1024|
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
|SDPA EA|82T|92T|83T|84T|78T|80T|78T|80T|78T|80T|78T|79T|
|FFPA L1*|136T|135T|135T|132T|133T|133T|132T|131T|130T|125T|123T|93T|
|Speedup|1.64x|1.45x|1.61x|1.57x|1.71x|1.65x|1.68x|1.62x|1.65x|1.56x|1.55x|1.17x|
|FFPA L1^|154T|161T|160T|157T|156T|155T|157T|154T|149T|150T|145T|100T|
|Speedup|1.85x|1.73x|1.92x|1.87x|1.99x|1.93x|1.99x|1.90x|1.90x|1.88x|1.84x|1.25x|

- NVIDIA L20 (`*`=MMA Acc F32, `^`=MMA Acc F16, `T`=TFLOPS)
|Algorithm|320|384|448|512|576|640|704|768|832|896|960|1024|
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
|SDPA EA|56T|63T|57T|58T|55T|56T|54T|55T|54T|55T|54T|56T|
|FFPA L1*|99T|95T|95T|93T|94T|92T|92T|90T|89T|90T|90T|89T|
|Speedup|1.77x|1.49x|1.64x|1.58x|1.72x|1.65x|1.68x|1.63x|1.64x|1.63x|1.67x|1.58x|
|FFPA L1^|96T|99T|100T|92T|93T|92T|93T|91T|90T|90T|88T|91T|
|Speedup|1.71x|1.55x|1.73x|1.56x|1.69x|1.65x|1.71x|1.64x|1.65x|1.63x|1.62x|1.62x|

- NVIDIA A30 (`*`=MMA Acc F32, `^`=MMA Acc F16, `T`=TFLOPS)
|Algorithm|320|384|448|512|576|640|704|768|832|896|960|1024|
|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|:---:|
|SDPA EA|25T|25T|24T|23T|24T|24T|23T|22T|22T|21T|21T|18T|
|FFPA L1*|33T|33T|32T|31T|32T|32T|30T|28T|25T|24T|24T|24T|
|Speedup|1.33x|1.33x|1.30x|1.31x|1.33x|1.33x|1.32x|1.23x|1.15x|1.11x|1.11x|1.27x|
|FFPA L1^|33T|33T|33T|30T|31T|32T|31T|30T|30T|27T|24T|23T|
|Speedup|1.33x|1.33x|1.36x|1.30x|1.31x|1.33x|1.37x|1.35x|1.35x|1.25x|1.11x|1.25x|

## Python Testing
You can test many custom FFPA kernels via Python and compare their performance.
```bash
# Optionally restrict compilation to a single GPU architecture (Volta, Ampere, Ada, Hopper, ...):
export TORCH_CUDA_ARCH_LIST=Ada # for Ada only
export TORCH_CUDA_ARCH_LIST=Ampere # for Ampere only
cd tests && python3 test.py --B 1 --H 48 --N 8192 --show-all --D 320
```
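The TFLOPS figures printed by `test.py` (and collected in the tables above) are consistent with the standard forward-attention FLOP count of 4·B·H·N²·D (2·B·H·N²·D each for Q@K^T and P@V); a quick sanity check in Python, assuming that accounting:

```python
# Convert a measured time into TFLOPS using the standard forward-attention
# FLOP count: 2*B*H*N^2*D for Q@K^T plus 2*B*H*N^2*D for P@V.
def attn_tflops(B, H, N, D, time_ms):
    flops = 4 * B * H * (N ** 2) * D
    return flops / (time_ms * 1e-3) / 1e12

# SDPA EA entry from the log below: B=1, H=48, N=8192, D=320, 50.36 ms.
print(f"{attn_tflops(1, 48, 8192, 320, 50.36):.2f} TFLOPS")  # ~81.87
```

This lands close to the 82.19 TFLOPS reported in the log; the small gap comes from the printed time being rounded.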
- Case: B=1, H=48, N=8192, D=320 (`FA2 not supported`), Device=NVIDIA RTX 4090.
```bash
python3 test.py --B 1 --H 48 --N 8192 --show-all --D 320 # NVIDIA RTX 4090
-------------------------------------------------------------------------------------------------
-----------------------------B=1, H=48, N=8192, D=320, Warmup: 1, Iters: 5-----------------------
(sdpa): ['-0.01750183 '], time:50.36ms, TFLOPS:82.19 (+0.00 %)(~1.00x)
(ffpa+acc+f32+L1+stage1): ['-0.01754761 '], time:40.23ms, TFLOPS:102.87(+25.17%)(~1.25x)
(ffpa+acc+f32+L1+stage2): ['-0.01754761 '], time:30.35ms, TFLOPS:136.34(+32.54%)(~1.66x)
(ffpa+acc+f16+L1+stage1): ['-0.01747131 '], time:31.03ms, TFLOPS:133.27(+0.00 %)(~1.62x)
(ffpa+acc+f16+L1+stage2): ['-0.01747131 '], time:26.98ms, TFLOPS:153.41(+12.51%)(~1.87x)
-------------------------------------------------------------------------------------------------
```

## ©️License
GNU General Public License v3.0
## Contribute
How to contribute? You are welcome to star this repo to support me ~
## References
- [flash-attention](https://github.com/Dao-AILab/flash-attention)
- [CUDA-Learn-Notes](https://github.com/DefTruth/CUDA-Learn-Notes)
- [flashinfer](https://github.com/flashinfer-ai/flashinfer)