https://github.com/alepot55/flash-reasoning

Tree-Aware Attention for System 2 Reasoning. Reduces KV-cache VRAM by 96% and reaches an effective bandwidth of 1.33× the physical HBM limit via fused GQA Triton kernels.

attention-mechanism fused-gqa kv-cache reasoning triton


# ⚡ Flash-Reasoning

**Tree-Aware KV-Cache Attention for Reasoning LLMs**

Python 3.10+ • PyTorch • Triton • MIT License

---

## The Problem

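Reasoning LLMs (DeepSeek-R1, o1) and search strategies such as Tree-of-Thought generate **decision trees** during inference: many branches continue from one shared prefix. Standard engines store each branch as an independent linear sequence, so `b` branches sharing an `n`-token prefix duplicate its KV cache `b` times: **O(n × b)** memory instead of O(n).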

```mermaid
flowchart LR
subgraph Standard["❌ Standard (vLLM)"]
direction TB
B1["Branch 1: [Prefix | Suffix]"] --> M1["2112 tokens"]
B2["Branch 2: [Prefix | Suffix]"] --> M2["2112 tokens"]
B3["Branch 256: [Prefix | Suffix]"] --> M3["2112 tokens"]
end

subgraph Tree["✅ Flash-Reasoning"]
direction TB
P["Shared Prefix<br/>2048 tokens (1×)"] --> S1["Suffix 1<br/>64 tok"]
P --> S2["Suffix 2<br/>64 tok"]
P --> S3["Suffix 256<br/>64 tok"]
end

Standard -.->|"541K tokens"| X["❌"]
Tree -.->|"18K tokens"| Y["✅ 29× less"]
```
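The token counts in the diagram follow from counting KV entries per branch. A minimal sketch, assuming every branch has the same 64-token suffix and ignoring block-size rounding; `TreeShape` and the helper functions are illustrative names, not part of `flash_reasoning`.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class TreeShape:
    prefix_tokens: int  # tokens shared by every branch
    suffix_tokens: int  # private tokens per branch
    num_branches: int


def linear_kv_tokens(s: TreeShape) -> int:
    # Linear layout: every branch stores its own copy of [prefix | suffix].
    return s.num_branches * (s.prefix_tokens + s.suffix_tokens)


def tree_kv_tokens(s: TreeShape) -> int:
    # Tree layout: the prefix is stored once, only suffixes are per branch.
    return s.prefix_tokens + s.num_branches * s.suffix_tokens


shape = TreeShape(prefix_tokens=2048, suffix_tokens=64, num_branches=256)
linear, tree = linear_kv_tokens(shape), tree_kv_tokens(shape)
print(linear, tree)  # 540672 18432
print(f"{linear / tree:.1f}x less, {1 - tree / linear:.1%} saved")  # 29.3x less, 96.6% saved
```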

---

## Results

**2.54× faster** | **96.6% less VRAM** | **L2 cache reuse of the shared prefix** (at 256 branches)

### Speedup


*Figure: Speedup Chart*

### Memory Bandwidth


*Figure: Bandwidth Chart*

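> Effective bandwidth (KV bytes the attention logically reads, divided by kernel time) exceeds the HBM limit (900 GB/s) because repeated reads of the shared prefix blocks are served from L2 cache (~5 TB/s) instead of HBM.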
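Effective bandwidth counts the bytes a linear (unshared) layout would read, so prefix reads that hit L2 still count toward the total. A minimal sketch of that accounting, assuming fp16 K and V for a single layer; the helper names are hypothetical, and the example shapes are borrowed from the Quick Start rather than taken from the benchmark script, so the printed ratio is only illustrative.

```python
HBM_PEAK_GBPS = 900.0  # HBM limit quoted in the note above


def logical_kv_bytes(num_branches: int, seq_len: int, num_kv_heads: int,
                     head_dim: int, bytes_per_elem: int = 2) -> int:
    """K and V bytes for every token of every branch, counting the shared prefix once per branch."""
    return num_branches * seq_len * num_kv_heads * head_dim * 2 * bytes_per_elem


def effective_bandwidth_gbps(num_bytes: int, kernel_ms: float) -> float:
    """Logical bytes divided by wall-clock kernel time, in GB/s (1 GB = 1e9 bytes)."""
    return num_bytes / (kernel_ms * 1e-3) / 1e9


# Hypothetical configuration: 8 KV heads, head_dim 128, fp16, 256 branches of
# 2112 tokens, timed at the 1.859 ms reported for batch 256.
nbytes = logical_kv_bytes(256, 2112, 8, 128)
bw = effective_bandwidth_gbps(nbytes, 1.859)
print(f"{nbytes / 1e9:.2f} GB logical, {bw:.0f} GB/s, {bw / HBM_PEAK_GBPS:.2f}x HBM peak")
```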

### VRAM Reduction


*Figure: VRAM Chart*

### Benchmark Table

| Batch (branches) | Tree-aware latency | Linear latency | Speedup | VRAM Reduction |
|------:|-----:|-------:|--------:|---------------:|
| 1 | 0.089 ms | 0.091 ms | 1.02× | 3% |
| 16 | 0.184 ms | 0.301 ms | 1.64× | 74% |
| 64 | 0.483 ms | 1.102 ms | 2.28× | 92% |
| 128 | 0.912 ms | 2.281 ms | 2.50× | 95% |
| **256** | **1.859 ms** | **4.729 ms** | **2.54×** | **96.6%** |
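The Speedup column is the linear latency divided by the tree-aware latency for the same batch. A quick check that reproduces it from the two latency columns (values copied from the table):

```python
# Latencies from the benchmark table: batch -> (tree-aware ms, linear ms)
LATENCIES_MS = {
    1: (0.089, 0.091),
    16: (0.184, 0.301),
    64: (0.483, 1.102),
    128: (0.912, 2.281),
    256: (1.859, 4.729),
}


def speedup(tree_ms: float, linear_ms: float) -> float:
    """How many times faster the tree-aware kernel is than the linear baseline."""
    return linear_ms / tree_ms


for batch, (tree_ms, linear_ms) in LATENCIES_MS.items():
    print(f"{batch:>3}: {speedup(tree_ms, linear_ms):.2f}x")
```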

---

## Features

| Feature | Description |
|---------|-------------|
| **Physical Prefix Sharing** | Branches share KV blocks via reference counting |
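| **Fused GQA Kernel** | Each K/V tile is loaded once per KV head and reused by every query head in its group (4-8× less K/V traffic for group sizes of 4-8) |
| **Online Softmax** | FlashAttention-style streaming softmax: O(1) extra state (running max and sum) per query, no materialized score matrix |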
| **Triton Autotuning** | Automatic optimization for A100/H100/RTX |
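The Fused GQA and Online Softmax rows describe what the kernel computes for one decode step. A pure-Python sketch of that computation, written for readability rather than speed and not taken from the repo's Triton kernel: it assumes query head `h` reads KV head `h // group` (a common GQA layout), walks the sequence in fixed-size tiles, fetches each K/V tile once per KV head for the whole group, and keeps a running max, running denominator and unnormalized accumulator per query head. `reference_attention` is ordinary two-pass softmax attention, used to check the streaming result.

```python
import math
from typing import List

Vec = List[float]


def gqa_decode_attention(q: List[Vec], k: List[List[Vec]], v: List[List[Vec]],
                         block_size: int = 16) -> List[Vec]:
    """One decode token, grouped KV heads, online softmax over K/V tiles.

    q: [num_q_heads][head_dim]; k, v: [num_kv_heads][seq_len][head_dim], seq_len > 0.
    Assumes query head h reads KV head h // group (a common GQA layout).
    """
    num_q, num_kv, dim = len(q), len(k), len(q[0])
    if num_q % num_kv:
        raise ValueError("num_q_heads must be a multiple of num_kv_heads")
    group = num_q // num_kv
    scale = 1.0 / math.sqrt(dim)
    out: List[Vec] = [[] for _ in range(num_q)]

    for kv in range(num_kv):
        heads = range(kv * group, (kv + 1) * group)
        run_max = {h: -math.inf for h in heads}  # running max score per query head
        denom = {h: 0.0 for h in heads}          # running softmax denominator
        acc = {h: [0.0] * dim for h in heads}    # running unnormalized output
        for start in range(0, len(k[kv]), block_size):
            # Each K/V tile is fetched once and reused by every query head in the group.
            k_tile = k[kv][start:start + block_size]
            v_tile = v[kv][start:start + block_size]
            for h in heads:
                s = [scale * sum(a * b for a, b in zip(q[h], kt)) for kt in k_tile]
                new_max = max(run_max[h], max(s))
                alpha = math.exp(run_max[h] - new_max)  # rescales old state; 0.0 on first tile
                p = [math.exp(x - new_max) for x in s]
                denom[h] = denom[h] * alpha + sum(p)
                acc[h] = [acc[h][d] * alpha + sum(pi * vt[d] for pi, vt in zip(p, v_tile))
                          for d in range(dim)]
                run_max[h] = new_max
        for h in heads:
            out[h] = [a / denom[h] for a in acc[h]]
    return out


def reference_attention(q: List[Vec], k: List[List[Vec]], v: List[List[Vec]]) -> List[Vec]:
    """Ordinary two-pass softmax attention with the same head mapping, for checking."""
    group, dim = len(q) // len(k), len(q[0])
    out = []
    for h, qh in enumerate(q):
        keys, vals = k[h // group], v[h // group]
        s = [sum(a * b for a, b in zip(qh, kt)) / math.sqrt(dim) for kt in keys]
        mx = max(s)
        w = [math.exp(x - mx) for x in s]
        total = sum(w)
        out.append([sum(wi * vt[d] for wi, vt in zip(w, vals)) / total for d in range(dim)])
    return out


# Deterministic toy data: 4 query heads sharing 2 KV heads, 37 tokens, head_dim 8.
q = [[math.sin(3 * h + d) for d in range(8)] for h in range(4)]
k = [[[math.cos(t + 5 * g + d) for d in range(8)] for t in range(37)] for g in range(2)]
v = [[[math.sin(2 * t - g + d) for d in range(8)] for t in range(37)] for g in range(2)]
fast, slow = gqa_decode_attention(q, k, v), reference_attention(q, k, v)
max_err = max(abs(a - b) for fh, sh in zip(fast, slow) for a, b in zip(fh, sh))
print(f"max abs difference: {max_err:.2e}")
```

With 37 tokens and 16-token tiles the loop also exercises a partial last tile, which is where an off-by-one in the rescaling would show up.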

---

## Installation

```bash
git clone https://github.com/alepot55/flash-reasoning.git
cd flash-reasoning
uv sync --all-extras
```

---

## Quick Start

```python
import torch
from flash_reasoning import PhysicalKVAllocator, tree_attention

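# Initialize the block allocator (16 tokens per block, 8 KV heads of dim 128).
# The tree below needs 128 prefix blocks + 256 branches × 4 suffix blocks = 1152 blocks.
allocator = PhysicalKVAllocator(
    num_blocks=2048, block_size=16, num_kv_heads=8, head_dim=128, device="cuda"
)

# Allocate the shared 2048-token prefix once
root = allocator.alloc_branch(num_tokens=2048)

# Fork 256 branches of 2112 tokens each; tokens [0, 2048) reuse the root's blocks,
# so each branch only adds blocks for its own 64-token suffix
branches = [
    allocator.alloc_branch(2112, parent_branch_id=root.branch_id, fork_position=2048)
    for _ in range(256)
]

# One decode query per branch: [num_branches, num_q_heads, head_dim].
# 32 query heads over 8 KV heads gives a GQA group size of 4.
q = torch.randn(256, 32, 128, device="cuda", dtype=torch.float16)
output = tree_attention(q, allocator, [b.branch_id for b in branches])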
```
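Because blocks, not tokens, are the unit of allocation, `num_blocks` has to cover the prefix once plus every branch's private suffix. A minimal sketch of that count; `blocks_for_tree` is a hypothetical helper, and it assumes the fork position falls on a block boundary (as 2048 does with 16-token blocks) so whole prefix blocks can be shared without copying a partially filled one.

```python
def ceil_div(a: int, b: int) -> int:
    return -(-a // b)


def blocks_for_tree(prefix_tokens: int, branch_tokens: int, num_branches: int,
                    block_size: int) -> int:
    """Physical blocks for one shared prefix plus num_branches private suffixes.

    Hypothetical helper; assumes a block-aligned fork position so prefix blocks
    are shared whole and no partially filled block has to be copied.
    """
    if prefix_tokens % block_size:
        raise ValueError("sketch assumes a block-aligned fork position")
    prefix_blocks = prefix_tokens // block_size
    suffix_blocks = ceil_div(branch_tokens - prefix_tokens, block_size)
    return prefix_blocks + num_branches * suffix_blocks


# Quick Start tree: 2048-token prefix, 256 branches of 2112 tokens, 16-token blocks.
print(blocks_for_tree(2048, 2112, 256, 16))  # 1152
print(256 * ceil_div(2112, 16))              # 33792 blocks without sharing
```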

---

## Architecture

```
src/flash_reasoning/
├── core/memory.py # PhysicalKVAllocator (block alloc + refcount)
├── kernels/tree_attention.py # Triton kernels (fused GQA + online softmax)
└── ops/attention.py # tree_attention() wrapper
```
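`core/memory.py` is described as block allocation plus reference counting. The sketch below shows that idea in isolation; `RefCountedBlockPool` and its methods are hypothetical names, not the `PhysicalKVAllocator` API. Forking bumps the reference count of the parent's prefix blocks instead of copying them, and a block returns to the free list only when its last owner releases it.

```python
from typing import Dict, List


class RefCountedBlockPool:
    """Hypothetical sketch of physical block allocation with reference counting."""

    def __init__(self, num_blocks: int) -> None:
        self.free: List[int] = list(range(num_blocks))
        self.refs: Dict[int, int] = {}

    def alloc(self, n: int) -> List[int]:
        # Take n fresh blocks, each owned by exactly one branch.
        if n > len(self.free):
            raise MemoryError(f"need {n} blocks, only {len(self.free)} free")
        blocks = [self.free.pop() for _ in range(n)]
        for b in blocks:
            self.refs[b] = 1
        return blocks

    def share(self, blocks: List[int]) -> List[int]:
        # A forked branch reuses the parent's blocks: bump refcounts, copy nothing.
        for b in blocks:
            self.refs[b] += 1
        return list(blocks)

    def release(self, blocks: List[int]) -> None:
        # Drop one reference; recycle a block only when nobody uses it anymore.
        for b in blocks:
            self.refs[b] -= 1
            if self.refs[b] == 0:
                del self.refs[b]
                self.free.append(b)


pool = RefCountedBlockPool(num_blocks=8)
prefix = pool.alloc(2)                           # shared prefix
branch_a = pool.share(prefix) + pool.alloc(1)    # prefix + private suffix block
branch_b = pool.share(prefix) + pool.alloc(1)
pool.release(prefix)     # root finished; both branches still hold the prefix
pool.release(branch_a)   # frees only branch_a's private block
print(len(pool.free), [pool.refs[b] for b in prefix])  # 5 [1, 1]
```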

---

## Benchmarks

```bash
uv run python benchmarks/benchmark_throughput.py --batch-sizes 64 128 256
```

---

## Tests

```bash
uv run pytest tests/ -v
# 9 passed ✓
```

---

## Citation

```bibtex
@software{flash_reasoning_2025,
title = {Flash-Reasoning: Tree-Aware KV-Cache Attention},
author = {Potenza, Alessandro},
year = {2025},
url = {https://github.com/alepot55/flash-reasoning}
}
```

---

## Related Work

- [FlashAttention](https://github.com/Dao-AILab/flash-attention) — IO-aware attention
- [vLLM](https://github.com/vllm-project/vllm) — PagedAttention
- [SGLang](https://github.com/sgl-project/sglang) — RadixAttention

---


MIT License • Built with Triton