https://github.com/daedalus/ssa
O(N·K) multi-head attention for PyTorch — a sparse drop-in replacement for dense scaled-dot-product attention.
- Host: GitHub
- URL: https://github.com/daedalus/ssa
- Owner: daedalus
- License: mit
- Created: 2026-06-30T17:04:46.000Z (5 days ago)
- Default Branch: master
- Last Pushed: 2026-07-03T18:22:06.000Z (2 days ago)
- Last Synced: 2026-07-03T20:20:15.372Z (2 days ago)
- Topics: attention-mechanism, deep-learning, efficient-computation, long-context, lsh, machine-learning, pytorch, sparse-attention, transformer
- Language: Python
- Size: 263 KB
- Stars: 0
- Watchers: 0
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
# SparseAttention
O(N·K) multi-head attention for PyTorch — a sparse drop-in replacement for
dense scaled-dot-product attention. Combines a local window, a fixed set of
global tokens, and content-based LSH routing into a single per-query
neighbor list, then attends only to that list instead of the full sequence.
## Quick Start
### Installation
No extra dependencies — just PyTorch:
```bash
pip install torch
```
Then copy `sparse_attention.py` into your project, or clone the repo:
```bash
git clone https://github.com/daedalus/SSA.git
cd SSA
```
### Basic Usage
```python
import torch
from sparse_attention import SSAConfig, SparseAttention
# Configure and create the attention module
cfg = SSAConfig(d_model=512, num_heads=8, num_neighbors=128,
                window_size=8, num_global_tokens=2, causal=True)
attn = SparseAttention(cfg)
# Self-attention
x = torch.randn(2, 1024, 512) # (batch, seq_len, d_model)
out, _ = attn(x)
# Cross-attention (e.g., encoder-decoder)
enc_out = torch.randn(2, 512, 512) # encoder output
out, _ = attn(x, key_value=enc_out)
```
### Grouped-Query Attention (GQA / MQA)
```python
# Standard MHA (default)
cfg = SSAConfig(num_heads=8)
# Grouped-query attention (LLaMA-2 / Mistral style)
cfg = SSAConfig(num_heads=32, num_kv_heads=8) # 4 query heads share each KV head
# Multi-query attention (single shared KV head, max memory savings)
cfg = SSAConfig(num_heads=32, num_kv_heads=1)
```
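The sharing pattern comes down to a little index arithmetic. The sketch below assumes the common contiguous grouping, in which consecutive query heads read from the same KV head. `kv_head_for_query_head` is a hypothetical helper written for illustration, not part of `sparse_attention.py`, and the module may lay out its heads differently.

```python
def kv_head_for_query_head(q_head: int, num_heads: int, num_kv_heads: int) -> int:
    """Map a query head index to the KV head it reads from (contiguous grouping)."""
    if num_heads % num_kv_heads != 0:
        raise ValueError("num_heads must be divisible by num_kv_heads")
    group_size = num_heads // num_kv_heads  # query heads per KV head
    return q_head // group_size


# GQA: 32 query heads, 8 KV heads, so groups of 4 query heads
gqa_map = [kv_head_for_query_head(h, 32, 8) for h in range(32)]

# MQA: every query head reads the single KV head
mqa_map = [kv_head_for_query_head(h, 32, 1) for h in range(32)]
```

With `num_kv_heads=None` (the default) the mapping is presumably the identity, i.e. standard multi-head attention.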
### Full Transformer
```python
from sparse_attention import SparseTransformer
model = SparseTransformer(cfg, num_layers=6, vocab_size=32000)
token_ids = torch.randint(0, 32000, (2, 1024))
out, stats_per_layer = model(token_ids) # no stats
out, stats_per_layer = model(token_ids, return_stats=True) # with stats
```
### Explicit Global Token Indices
```python
# Force attention to BOS, CLS, and a mid-document landmark
cfg = SSAConfig(global_token_indices=[0, 1, 512])
```
## Configuration
All behavior is controlled by `SSAConfig`:
```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class SSAConfig:
    d_model: int = 512
    num_heads: int = 8
    num_kv_heads: Optional[int] = None  # GQA/MQA
    num_neighbors: int = 128  # K: total neighbor slots per query
    max_num_hashes: int = 12  # ceiling on LSH planes (2^P buckets)
    num_hash_rounds: int = 8  # independent hash rounds, unioned
    lsh_num_probes: int = 0  # multi-probe: extra near-boundary buckets
    window_size: int = 8  # local window half-width
    num_global_tokens: int = 2  # leading key tokens, all queries attend
    global_token_indices: Optional[list] = None  # explicit global positions
    dropout: float = 0.0
    causal: bool = False
    fp32_attn_weights: bool = False  # keep post-softmax weights in FP32
```
### Picking `num_neighbors` vs `window_size` + `num_global_tokens`
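Keep the guaranteed budget (`2*window_size + 1 + num_global_tokens + 1`, covering
the window, the global tokens, and the self slot) comfortably under
`num_neighbors` (K), so that slots remain for LSH candidates. A reasonable rule
of thumb for `window_size` is `max(1, K // 8)`.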
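As a quick arithmetic check, the budget formula and the rule of thumb can be evaluated directly. The helper names below are hypothetical, and treating the remaining slots as "available to LSH" is an assumption about how the module fills its neighbor list.

```python
def guaranteed_budget(window_size: int, num_global_tokens: int) -> int:
    """Slots always taken by window, global, and self entries (formula above)."""
    return 2 * window_size + 1 + num_global_tokens + 1


def suggested_window_size(num_neighbors: int) -> int:
    """Rule of thumb from above: max(1, K // 8)."""
    return max(1, num_neighbors // 8)


K = 128
budget = guaranteed_budget(window_size=8, num_global_tokens=2)  # Quick Start values
lsh_slots = K - budget  # slots assumed to be left for LSH candidates
```

With the Quick Start values the guaranteed budget is 20 of 128 slots, leaving 108. The rule of thumb would suggest `window_size=16` for K=128, which raises the budget to 36.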
## How It Works
For each query token, the candidate key set is the union of four sources:
1. **Self** — every token always attends to itself
2. **Window** — the `2·window_size + 1` nearest positions (causal: only the trailing half)
3. **Global** — a fixed set of key positions every query attends to
4. **LSH** — content-based candidates found via multi-round, multi-plane locality-sensitive hashing
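The three positional sources in the list above (self, window, global) can be sketched in plain Python. This is an illustration under stated assumptions, not the module's code: `positional_candidates` is a hypothetical name, and dropping global positions that lie in the future under causal masking is an assumption.

```python
from typing import List


def positional_candidates(q: int, seq_len: int, window_size: int,
                          global_indices: List[int], causal: bool) -> List[int]:
    """Sorted, de-duplicated key positions from the self, window, and global sources."""
    lo = max(0, q - window_size)
    # Causal: the window stops at the query position itself.
    hi = q if causal else min(seq_len - 1, q + window_size)

    candidates = {q}                          # self
    candidates.update(range(lo, hi + 1))      # local window
    for g in global_indices:                  # global tokens
        # Assumption: a causal query never sees a global position ahead of it.
        if 0 <= g < seq_len and (not causal or g <= q):
            candidates.add(g)
    return sorted(candidates)


causal_list = positional_candidates(10, 32, 3, [0, 1], causal=True)
full_list = positional_candidates(10, 32, 3, [0, 1], causal=False)
```

Using a set makes the union idempotent, so a position that is both a global token and inside the window occupies one slot only.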
The LSH bucket count is **adaptive**: it is computed from the sequence length so that average bucket occupancy stays near the per-round candidate budget, which keeps recall high across sequence lengths from 64 to 32,768 tokens.
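One way to read the adaptive rule: with P hyperplanes there are 2^P buckets, so average occupancy is roughly N / 2^P. The sketch below picks P to bring that near a target and shows sign-bit hashing for a single round. Both functions are hypothetical illustrations; the exact formula, the rounding, and the use of sign-bit hyperplane hashing are assumptions rather than the module's confirmed behavior.

```python
import math
from typing import Sequence


def adaptive_num_planes(seq_len: int, target_occupancy: int,
                        max_num_hashes: int = 12) -> int:
    """Pick P so that seq_len / 2**P is close to target_occupancy, 1 <= P <= max_num_hashes."""
    if seq_len <= target_occupancy:
        return 1
    planes = round(math.log2(seq_len / target_occupancy))
    return max(1, min(max_num_hashes, planes))


def bucket_id(vector: Sequence[float], planes: Sequence[Sequence[float]]) -> int:
    """Sign-bit LSH: one bit per hyperplane, packed into an integer bucket index."""
    bucket = 0
    for plane in planes:
        dot = sum(v * p for v, p in zip(vector, plane))
        bucket = (bucket << 1) | (1 if dot >= 0 else 0)
    return bucket


p_short = adaptive_num_planes(64, target_occupancy=16)     # 4 buckets
p_long = adaptive_num_planes(32768, target_occupancy=16)   # 2048 buckets
```

Vectors that fall on the same side of every hyperplane share a bucket, which is why similar queries and keys tend to collide. Repeating this over `num_hash_rounds` independent plane sets and taking the union of the matches raises the chance that a true neighbor is found at least once.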
## Recall and Quality
Benchmarked against exact dense top-K attention on random embeddings:
```
N=1024, K=128, R=8, window=8, true_k=32
Sparse pipeline recall@32: 99.3%
Random-K-selection recall: 12.5% (floor)
Ratio vs random floor: 7.94x
```
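For reference, recall against dense top-K can be computed as below. This is a minimal pure-Python sketch of the metric, not the repository's benchmark code: `recall_at_k` and `random_floor` are hypothetical names, and the scores would normally come from the dense query-key dot products.

```python
from typing import List


def recall_at_k(scores: List[List[float]], neighbors: List[List[int]],
                true_k: int) -> float:
    """Fraction of each query's dense top-`true_k` keys found in its neighbor list."""
    hits = 0
    for q, row in enumerate(scores):
        ranked = sorted(range(len(row)), key=lambda j: row[j], reverse=True)
        top = set(ranked[:true_k])
        hits += len(top & set(neighbors[q]))
    return hits / (len(scores) * true_k)


def random_floor(num_neighbors: int, seq_len: int) -> float:
    """Expected recall when the K neighbors are chosen uniformly at random."""
    return num_neighbors / seq_len


# Toy example: 2 queries, 3 keys, neighbor lists of size 2
toy_scores = [[0.9, 0.1, 0.5], [0.2, 0.8, 0.3]]
toy_neighbors = [[0, 2], [0, 1]]
toy_recall = recall_at_k(toy_scores, toy_neighbors, true_k=2)
floor = random_floor(128, 1024)
```

The random floor for K=128 and N=1024 is 128/1024 = 12.5%, which matches the figure reported above.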
Recall across sequence lengths:
```
   N    Before     After      Gain
-----------------------------------
  64     89.5%     99.2%     +9.7%
 128     85.9%     99.1%    +13.2%
 256     83.7%     99.6%    +16.0%
 512     73.3%     99.9%    +26.6%
1024     54.7%     99.9%    +45.1%
2048     38.1%     99.3%    +61.1%
4096     27.0%     95.3%    +68.3%
```
## Memory Efficiency
Peak LSH rescore memory (K=128, R=8, os=8):
```
   N     Before      After    Saved
------------------------------------
 512    0.98 GB    0.14 GB       7×
1024    1.97 GB    0.28 GB       7×
2048    3.93 GB    0.56 GB       7×
4096    7.87 GB    1.12 GB       7×
```
## Benchmarks
```bash
python benchmarks/bench_dense_vs_sparse.py
```
**CPU performance (single core):**
```
   N | Dense (ms) | Sparse (ms) | Speedup | Mem ratio
 256 |       3.94 |       62.39 |   0.06x |      2.0x
 512 |       7.08 |      126.09 |   0.06x |      4.0x
1024 |      17.44 |      259.18 |   0.07x |      8.0x
2048 |      56.02 |      544.34 |   0.10x |     16.0x
4096 |     202.81 |     1131.41 |   0.18x |     32.0x
8192 |     782.37 |     2467.90 |   0.32x |     64.0x
```
Dense attention is currently faster on CPU at every N tested, though the gap narrows as N grows (0.06x at N=256, 0.32x at N=8192). The memory savings (up to 64× at N=8192) are the main advantage for long sequences. On GPU with custom kernels, sparse attention's O(N·K) complexity is expected to yield both speed and memory wins.
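The `Mem ratio` column is consistent with counting attention-weight entries per head: N·N for dense versus N·K for sparse, at K=128. That reading is an inference from the numbers in the table, not a confirmed description of what the benchmark script measures. `attention_entries_ratio` is a hypothetical helper.

```python
def attention_entries_ratio(seq_len: int, num_neighbors: int) -> float:
    """Dense N*N attention-weight entries divided by sparse N*K entries."""
    return (seq_len * seq_len) / (seq_len * num_neighbors)


ratios = {n: attention_entries_ratio(n, 128)
          for n in (256, 512, 1024, 2048, 4096, 8192)}
```

The ratio simplifies to N / K, so it doubles each time the sequence length doubles while K stays fixed.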
## What's NOT Supported
- **External attention masks** — use `config.causal=True` for autoregressive masking; for padding, zero out positions before calling forward
- **KV caching for incremental decoding** — the full graph is rebuilt on every forward call
- **Exact recall guarantees** — LSH is approximate, not exact top-k
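The padding workaround from the first bullet can be sketched as follows. `zero_padded_positions` is a hypothetical helper, and it assumes right-padded batches with known per-sequence lengths. Zeroing the inputs is not equivalent to a true attention mask: padded positions can still be selected as neighbors, so treat this as an approximation.

```python
import torch


def zero_padded_positions(x: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    """Zero every position at or beyond each sequence's true length.

    x:       (batch, seq_len, d_model)
    lengths: (batch,) number of real tokens per sequence (right padding assumed)
    """
    seq_len = x.shape[1]
    positions = torch.arange(seq_len, device=x.device)
    keep = positions.unsqueeze(0) < lengths.unsqueeze(1)  # (batch, seq_len)
    return x * keep.unsqueeze(-1).to(x.dtype)


x = torch.ones(2, 4, 3)
lengths = torch.tensor([4, 2])
x_zeroed = zero_padded_positions(x, lengths)  # second sequence: last 2 positions zeroed
```

The result can then be passed to the attention module in place of `x`.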
## Testing
```bash
pip install pytest
pytest
```
| Test File | Covers |
|-----------|--------|
| `test_attention_shapes.py` | Self/cross-attention shape correctness |
| `test_neighbor_graph.py` | Deduplication, self-edge placement |
| `test_causal.py` | No future-token leakage |
| `test_gradients.py` | Gradient flow through sparse operations |
| `test_quality_and_scaling.py` | Memory scaling, recall@K benchmarks |
| `test_gqa_mqa.py` | Grouped/multi-query attention |
| `test_global_tokens_and_mask_guard.py` | Explicit global indices, mask rejection |
| `test_multiprobe_lsh.py` | Multi-probe LSH correctness |
| `test_build_apply_graph_split.py` | build_graph/apply_graph consistency |
| `test_cached_graph_transformer.py` | Cross-layer graph caching |
## License
MIT — see [LICENSE](LICENSE) for details.