https://github.com/alepot55/flash-reasoning
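Tree-Aware Attention for System 2 Reasoning. Reduces KV-cache VRAM by 96% (256 branches sharing one prefix) and reaches an effective bandwidth 1.33× the physical HBM limit, because shared prefix blocks are served from L2 cache, via fused GQA Triton kernels.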
attention-mechanism fused-gqa kv-cache reasoning triton
- Host: GitHub
- URL: https://github.com/alepot55/flash-reasoning
- Owner: alepot55
- License: mit
- Created: 2026-01-26T22:40:38.000Z
- Default Branch: master
- Last Pushed: 2026-01-26T23:37:06.000Z
- Last Synced: 2026-01-27T09:42:38.508Z
- Topics: attention-mechanism, fused-gqa, kv-cache, reasoning, triton
- Language: Python
- Size: 480 KB
- Stars: 0
- Watchers: 0
- Forks: 0
- Open Issues: 1
Metadata Files:
- Readme: README.md
- Contributing: CONTRIBUTING.md
- License: LICENSE
README
# ⚡ Flash-Reasoning

Tree-Aware KV-Cache Attention for Reasoning LLMs
---
## The Problem
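Reasoning LLMs and search methods (DeepSeek-R1, o1, Tree-of-Thought) generate **decision trees** during inference: many branches continue from one common prefix. Standard engines store each branch as an independent linear sequence, so when `b` branches share an `n`-token prefix, the KV cache holds `n × b` prefix tokens instead of `n`, an **O(n × b)** memory cost where O(n) would suffice.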
```mermaid
flowchart LR
    subgraph Standard["❌ Standard (vLLM)"]
        direction TB
        B1["Branch 1: [Prefix | Suffix]"] --> M1["2112 tokens"]
        B2["Branch 2: [Prefix | Suffix]"] --> M2["2112 tokens"]
        B3["Branch 256: [Prefix | Suffix]"] --> M3["2112 tokens"]
    end
    subgraph Tree["✅ Flash-Reasoning"]
        direction TB
        P["Shared Prefix<br/>2048 tokens (1×)"] --> S1["Suffix 1<br/>64 tok"]
        P --> S2["Suffix 2<br/>64 tok"]
        P --> S3["Suffix 256<br/>64 tok"]
    end
    Standard -.->|"541K tokens"| X["❌"]
    Tree -.->|"18K tokens"| Y["✅ 29× less"]
```
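The token totals in the diagram follow from simple arithmetic. A minimal sketch, assuming a 2048-token prefix and a 64-token suffix per branch as drawn; `kv_tokens` is a hypothetical helper, not part of the package:

```python
def kv_tokens(prefix: int, suffix: int, branches: int) -> tuple:
    """Return (linear, tree) KV-cache token counts for `branches` forks of one prefix."""
    linear = branches * (prefix + suffix)  # every branch stores its own copy of the prefix
    tree = prefix + branches * suffix      # prefix stored once, one suffix per branch
    return linear, tree


linear, tree = kv_tokens(prefix=2048, suffix=64, branches=256)
print(f"linear={linear:,} tree={tree:,} "
      f"ratio={linear / tree:.1f}x reduction={1 - tree / linear:.1%}")
# linear=540,672 tree=18,432 ratio=29.3x reduction=96.6%
```

The ratio agrees with the 29× label, and the reduction with the 96.6% headline figure for 256 branches.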
---
## Results
**2.54× faster** | **96.6% less VRAM** | **L2 cache exploitation** (batch of 256 branches)
### Memory Bandwidth
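> Effective bandwidth counts the K/V bytes the kernel logically reads. It exceeds the HBM limit (900 GB/s) because, after the first read, shared prefix blocks are served from L2 cache (~5 TB/s) instead of HBM.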
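To see how an "effective" figure can pass the HBM limit, here is a back-of-envelope sketch for the batch-256 row. It assumes the Quick Start shapes (fp16 K/V, 8 KV heads, head_dim 128, 2112-token contexts) and that every branch is charged for reading its full context; the benchmark script may count bytes differently.

```python
# Assumed shapes (from the Quick Start, not confirmed for the benchmark):
# fp16 K/V, 8 KV heads, head_dim 128, 256 branches of 2112 tokens each.
BYTES_PER_ELEM = 2          # fp16
NUM_KV_HEADS, HEAD_DIM = 8, 128
CONTEXT_LEN, BRANCHES = 2112, 256
TREE_LATENCY_S = 1.859e-3   # Tree latency at batch 256 (benchmark table)
HBM_PEAK = 900e9            # bytes/s, the HBM limit quoted above

# K and V for one token across all KV heads
bytes_per_token = 2 * NUM_KV_HEADS * HEAD_DIM * BYTES_PER_ELEM
# Bytes the kernel logically reads: every branch attends over its full context,
# even though the shared prefix exists only once in HBM
logical_bytes = BRANCHES * CONTEXT_LEN * bytes_per_token
effective_bw = logical_bytes / TREE_LATENCY_S

print(f"{logical_bytes / 1e9:.2f} GB read, {effective_bw / 1e9:.0f} GB/s, "
      f"{effective_bw / HBM_PEAK:.2f}x HBM peak")
# 2.21 GB read, 1191 GB/s, 1.32x HBM peak
```

Under these assumptions the result lands close to the 1.33× quoted in the project description.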
### Benchmark Table
| Batch (branches) | Tree latency | Linear latency | Speedup | VRAM reduction |
|------:|-----:|-------:|--------:|---------------:|
| 1 | 0.089 ms | 0.091 ms | 1.02× | 3% |
| 16 | 0.184 ms | 0.301 ms | 1.64× | 74% |
| 64 | 0.483 ms | 1.102 ms | 2.28× | 92% |
| 128 | 0.912 ms | 2.281 ms | 2.50× | 95% |
| **256** | **1.859 ms** | **4.729 ms** | **2.54×** | **96.6%** |
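The Speedup column is the linear latency divided by the tree latency, so it can be recomputed from the two timing columns as a quick consistency check (a sketch; the dictionary simply restates the table):

```python
# (tree_ms, linear_ms) per batch size, copied from the table above
rows = {1: (0.089, 0.091), 16: (0.184, 0.301), 64: (0.483, 1.102),
        128: (0.912, 2.281), 256: (1.859, 4.729)}

# Speedup = linear latency / tree latency
speedups = {b: linear / tree for b, (tree, linear) in rows.items()}
for b, s in speedups.items():
    print(f"batch {b:>3}: {s:.2f}x")
```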
---
## Features
| Feature | Description |
|---------|-------------|
| **Physical Prefix Sharing** | Branches share KV blocks via reference counting |
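| **Fused GQA Kernel** | Each K/V head is loaded once and reused by every query head in its GQA group (4-8× less K/V traffic) |
| **Online Softmax** | FlashAttention-style running max and sum: O(1) extra memory per query, independent of sequence length |
| **Triton Autotuning** | Kernel configurations autotuned per GPU (A100/H100/RTX) |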
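The online-softmax row refers to the standard FlashAttention rescaling trick. A minimal pure-Python sketch for a single query, shown next to a two-pass reference; it illustrates the technique only and is not the Triton code in `kernels/tree_attention.py`:

```python
import math


def online_softmax_attention(scores, values):
    """Attention output for one query, consuming (score, value) pairs as a stream.

    Only a running max `m`, running normalizer `l`, and unnormalized accumulator
    `acc` are kept, so extra state does not grow with sequence length.
    """
    m = -math.inf                  # max score seen so far
    l = 0.0                        # sum of exp(score - m) so far
    acc = [0.0] * len(values[0])   # sum of exp(score - m) * value so far
    for s, v in zip(scores, values):
        m_new = max(m, s)
        scale = math.exp(m - m_new)   # rescales old state; exp(-inf) == 0.0 on the first step
        p = math.exp(s - m_new)
        l = l * scale + p
        acc = [a * scale + p * x for a, x in zip(acc, v)]
        m = m_new
    return [a / l for a in acc]


def softmax_attention(scores, values):
    """Two-pass reference: materializes all weights before summing."""
    m = max(scores)
    w = [math.exp(s - m) for s in scores]
    z = sum(w)
    return [sum(wi * v[d] for wi, v in zip(w, values)) / z
            for d in range(len(values[0]))]


scores = [0.5, 3.0, -1.2, 2.2, 10.0]
values = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [-1.0, 0.5], [0.3, -0.3]]
print(online_softmax_attention(scores, values))
print(softmax_attention(scores, values))
```

Both functions return the same output up to floating-point rounding; the streaming version never holds the full weight vector.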
---
## Installation
```bash
git clone https://github.com/alepot55/flash-reasoning.git
cd flash-reasoning
uv sync --all-extras
```
---
## Quick Start
```python
import torch
from flash_reasoning import PhysicalKVAllocator, tree_attention
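# Initialize allocator: 16-token blocks, 8 KV heads of dim 128.
# 2048 prefix tokens + 256 × 64 suffix tokens = 18,432 tokens = 1,152 blocks,
# so 1024 blocks (16,384 tokens) would be too few for this example.
allocator = PhysicalKVAllocator(
    num_blocks=2048, block_size=16, num_kv_heads=8, head_dim=128, device="cuda"
)
# Allocate shared prefix
root = allocator.alloc_branch(num_tokens=2048)
# Fork into 256 branches of 2112 tokens: positions 0-2047 are shared with the
# root, so each branch only adds its own 64-token suffix
branches = [
    allocator.alloc_branch(2112, parent_branch_id=root.branch_id, fork_position=2048)
    for _ in range(256)
]
# Compute attention: one query per branch, 32 query heads over 8 KV heads (GQA group of 4)
q = torch.randn(256, 32, 128, device="cuda", dtype=torch.float16)
output = tree_attention(q, allocator, [b.branch_id for b in branches])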
```
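Physical prefix sharing with reference counting can be illustrated with a toy block pool. A minimal sketch, assuming block-aligned forks; `ToyBlockAllocator`, `alloc`, and `release` are hypothetical names and do not describe how `PhysicalKVAllocator` is implemented:

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class ToyBlockAllocator:
    """Hypothetical refcounted block pool; NOT the PhysicalKVAllocator API."""
    num_blocks: int
    block_size: int
    free: List[int] = field(init=False)
    refcount: List[int] = field(init=False)

    def __post_init__(self):
        self.free = list(range(self.num_blocks - 1, -1, -1))  # pop() yields block 0 first
        self.refcount = [0] * self.num_blocks

    def alloc(self, num_tokens: int, parent: Optional[List[int]] = None,
              fork_position: int = 0) -> List[int]:
        """Return a block table; blocks before fork_position are shared with parent."""
        table: List[int] = []
        if parent is not None:
            assert fork_position % self.block_size == 0, "sketch shares whole blocks only"
            for blk in parent[: fork_position // self.block_size]:
                self.refcount[blk] += 1      # share the physical block instead of copying it
                table.append(blk)
        needed = -(-num_tokens // self.block_size)  # ceil(num_tokens / block_size)
        while len(table) < needed:
            blk = self.free.pop()            # IndexError when the pool is exhausted
            self.refcount[blk] = 1
            table.append(blk)
        return table

    def release(self, table: List[int]) -> None:
        """Drop one reference per block; blocks reaching zero return to the pool."""
        for blk in table:
            self.refcount[blk] -= 1
            if self.refcount[blk] == 0:
                self.free.append(blk)


pool = ToyBlockAllocator(num_blocks=2048, block_size=16)
root = pool.alloc(2048)                                   # 128 blocks
branches = [pool.alloc(2112, parent=root, fork_position=2048) for _ in range(256)]
print(pool.num_blocks - len(pool.free))                   # 128 + 256 * 4 = 1152
print(pool.refcount[root[0]])                             # 1 + 256 = 257
```

Releasing a branch only frees its private suffix blocks; the prefix blocks survive until their count reaches zero.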
---
## Architecture
```
src/flash_reasoning/
├── core/memory.py # PhysicalKVAllocator (block alloc + refcount)
├── kernels/tree_attention.py # Triton kernels (fused GQA + online softmax)
└── ops/attention.py # tree_attention() wrapper
```
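The fused GQA saving comes from head grouping: with 32 query heads over 8 KV heads (the Quick Start shapes), each KV head serves 4 query heads. A pure-Python sketch of one decode step that counts how often a KV head is "loaded"; `gqa_decode` is a hypothetical reference, and in the real Triton kernel the saving is in bytes moved from HBM rather than Python list accesses:

```python
import math
import random


def gqa_decode(q, k, v, fused=True):
    """One decode step of grouped-query attention for a single branch.

    q: [num_q_heads][head_dim]; k, v: [num_kv_heads][seq_len][head_dim].
    Returns (outputs per query head, number of K/V head "loads").
    """
    num_q, num_kv = len(q), len(k)
    group = num_q // num_kv                   # query heads per KV head
    out, loads = [None] * num_q, 0
    for kv in range(num_kv):
        for g in range(group):
            h = kv * group + g                # query heads in this group all read KV head `kv`
            if g == 0 or not fused:           # fused: load once per group; unfused: once per query head
                keys, vals = k[kv], v[kv]
                loads += 1
            scale = 1 / math.sqrt(len(q[h]))
            s = [scale * sum(a * b for a, b in zip(q[h], key)) for key in keys]
            m = max(s)
            w = [math.exp(x - m) for x in s]
            z = sum(w)
            out[h] = [sum(wi * val[d] for wi, val in zip(w, vals)) / z
                      for d in range(len(vals[0]))]
    return out, loads


def rand(*shape):
    """Nested lists of standard-normal floats with the given shape."""
    if len(shape) == 1:
        return [random.gauss(0, 1) for _ in range(shape[0])]
    return [rand(*shape[1:]) for _ in range(shape[0])]


random.seed(0)
H_Q, H_KV, D, T = 32, 8, 4, 5                 # Quick Start head counts, tiny head_dim/seq_len
q, k, v = rand(H_Q, D), rand(H_KV, T, D), rand(H_KV, T, D)
out_fused, loads_fused = gqa_decode(q, k, v, fused=True)
out_naive, loads_naive = gqa_decode(q, k, v, fused=False)
print(loads_naive, loads_fused)               # 32 8 -> 4x fewer K/V loads
```

Both variants produce identical outputs; only the number of K/V reads changes, which is why the gain scales with the GQA group size.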
---
## Benchmarks
```bash
uv run python benchmarks/benchmark_throughput.py --batch-sizes 64 128 256
```
---
## Tests
```bash
uv run pytest tests/ -v
# 9 passed ✓
```
---
## Citation
```bibtex
@software{flash_reasoning_2025,
title = {Flash-Reasoning: Tree-Aware KV-Cache Attention},
author = {Potenza, Alessandro},
year = {2025},
url = {https://github.com/alepot55/flash-reasoning}
}
```
---
## Related Work
- [FlashAttention](https://github.com/Dao-AILab/flash-attention) — IO-aware attention
- [vLLM](https://github.com/vllm-project/vllm) — PagedAttention
- [SGLang](https://github.com/sgl-project/sglang) — RadixAttention
---
MIT License • Built with Triton