{"id":51440665,"url":"https://github.com/daedalus/ssa","last_synced_at":"2026-07-05T11:01:47.775Z","repository":{"id":369136505,"uuid":"1285299886","full_name":"daedalus/SSA","owner":"daedalus","description":"O(N·K) multi-head attention for PyTorch — a sparse drop-in replacement for dense scaled-dot-product attention.","archived":false,"fork":false,"pushed_at":"2026-07-03T18:22:06.000Z","size":269,"stargazers_count":0,"open_issues_count":0,"forks_count":0,"subscribers_count":0,"default_branch":"master","last_synced_at":"2026-07-03T20:20:15.372Z","etag":null,"topics":["attention-mechanism","deep-learning","efficient-computation","long-context","lsh","machine-learning","pytorch","sparse-attention","transformer"],"latest_commit_sha":null,"homepage":null,"language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"mit","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/daedalus.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null,"zenodo":null,"notice":null,"maintainers":null,"copyright":null,"agents":null,"dco":null,"cla":null}},"created_at":"2026-06-30T17:04:46.000Z","updated_at":"2026-07-03T18:22:09.000Z","dependencies_parsed_at":null,"dependency_job_id":null,"html_url":"https://github.com/daedalus/SSA","commit_stats":null,"previous_names":["daedalus/ssa"],"tags_count":null,"template":false,"template_full_name":null,"purl":"pkg:github/daedalus/SSA","repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/daedalus%2FSSA","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/daedalus%2FSSA/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/daedalus%2FSSA/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/daedalus%2FSSA/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/daedalus","download_url":"https://codeload.github.com/daedalus/SSA/tar.gz/refs/heads/master","sbom_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/daedalus%2FSSA/sbom","scorecard":null,"host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":286080680,"owners_count":35151638,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2026-05-26T15:22:16.424Z","status":"online","status_checked_at":"2026-07-05T02:00:06.290Z","response_time":100,"last_error":null,"robots_txt_status":"success","robots_txt_updated_at":"2025-07-24T06:49:26.215Z","robots_txt_url":"https://github.com/robots.txt","online":true,"can_crawl_api":true,"host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["attention-mechanism","deep-learning","efficient-computation","long-context","lsh","machine-learning","pytorch","sparse-attention","transformer"],"created_at":"2026-07-05T11:01:46.899Z","updated_at":"2026-07-05T11:01:47.755Z","avatar_url":"https://github.com/daedalus.png","language":"Python","funding_links":[],"categories":[],"sub_categories":[],"readme":"# SparseAttention\n\nO(N·K) multi-head attention for PyTorch — a sparse drop-in replacement for\ndense scaled-dot-product attention. Combines a local window, a fixed set of\nglobal tokens, and content-based LSH routing into a single per-query\nneighbor list, then attends only to that list instead of the full sequence.\n\n## Quick Start\n\n### Installation\n\nNo extra dependencies — just PyTorch:\n\n```bash\npip install torch\n```\n\nThen copy `sparse_attention.py` into your project, or clone the repo:\n\n```bash\ngit clone https://github.com/daedalus/SSA.git\ncd SSA\n```\n\n### Basic Usage\n\n```python\nimport torch\nfrom sparse_attention import SSAConfig, SparseAttention\n\n# Configure and create the attention module\ncfg = SSAConfig(d_model=512, num_heads=8, num_neighbors=128,\n                window_size=8, num_global_tokens=2, causal=True)\nattn = SparseAttention(cfg)\n\n# Self-attention\nx = torch.randn(2, 1024, 512)  # (batch, seq_len, d_model)\nout, _ = attn(x)\n\n# Cross-attention (e.g., encoder-decoder)\nenc_out = torch.randn(2, 512, 512)  # encoder output\nout, _ = attn(x, key_value=enc_out)\n```\n\n### Grouped-Query Attention (GQA / MQA)\n\n```python\n# Standard MHA (default)\ncfg = SSAConfig(num_heads=8)\n\n# Grouped-query attention (LLaMA-2 / Mistral style)\ncfg = SSAConfig(num_heads=32, num_kv_heads=8)  # 4 query heads share each KV head\n\n# Multi-query attention (single shared KV head, max memory savings)\ncfg = SSAConfig(num_heads=32, num_kv_heads=1)\n```\n\n### Full Transformer\n\n```python\nfrom sparse_attention import SparseTransformer\n\nmodel = SparseTransformer(cfg, num_layers=6, vocab_size=32000)\n\ntoken_ids = torch.randint(0, 32000, (2, 1024))\nout, stats_per_layer = model(token_ids)                     # no stats\nout, stats_per_layer = model(token_ids, return_stats=True)  # with stats\n```\n\n### Explicit Global Token Indices\n\n```python\n# Force attention to BOS, CLS, and a mid-document landmark\ncfg = SSAConfig(global_token_indices=[0, 1, 512])\n```\n\n## Configuration\n\nAll behavior is controlled by `SSAConfig`:\n\n```python\n@dataclass\nclass SSAConfig:\n    d_model: int = 512\n    num_heads: int = 8\n    num_kv_heads: Optional[int] = None       # GQA/MQA\n    num_neighbors: int = 128                 # K: total neighbor slots per query\n    max_num_hashes: int = 12                 # ceiling on LSH planes (2^P buckets)\n    num_hash_rounds: int = 8                 # independent hash rounds, unioned\n    lsh_num_probes: int = 0                  # multi-probe: extra near-boundary buckets\n    window_size: int = 8                     # local window half-width\n    num_global_tokens: int = 2               # leading key tokens, all queries attend\n    global_token_indices: Optional[list] = None  # explicit global positions\n    dropout: float = 0.0\n    causal: bool = False\n    fp32_attn_weights: bool = False          # keep post-softmax weights in FP32\n```\n\n### Picking `num_neighbors` vs `window_size` + `num_global_tokens`\n\nKeep the guaranteed budget (`2*window_size + 1 + num_global_tokens + 1`)\ncomfortably under `num_neighbors`. A good rule of thumb for `window_size`:\n`max(1, K // 8)`.\n\n## How It Works\n\nFor each query token, the candidate key set is the union of four sources:\n\n1. **Self** — every token always attends to itself\n2. **Window** — the `2·window_size + 1` nearest positions (causal: only the trailing half)\n3. **Global** — a fixed set of key positions every query attends to\n4. **LSH** — content-based candidates found via multi-round, multi-plane locality-sensitive hashing\n\nThe LSH bucket count is **adaptive** — computed from sequence length to keep average bucket occupancy near the per-round candidate budget, maximizing recall across sequence lengths from 64 to 32,768 tokens.\n\n## Recall and Quality\n\nBenchmarked against exact dense top-K attention on random embeddings:\n\n```\nN=1024, K=128, R=8, window=8, true_k=32\nSparse pipeline recall@32: 99.3%\nRandom-K-selection recall: 12.5%  (floor)\nRatio vs random floor: 7.94x\n```\n\nRecall across sequence lengths:\n\n```\n     N      Before      After       Gain\n-------------------------------------------\n    64      89.5%      99.2%      +9.7%\n   128      85.9%      99.1%     +13.2%\n   256      83.7%      99.6%     +16.0%\n   512      73.3%      99.9%     +26.6%\n  1024      54.7%      99.9%     +45.1%\n  2048      38.1%      99.3%     +61.1%\n  4096      27.0%      95.3%     +68.3%\n```\n\n## Memory Efficiency\n\nPeak LSH rescore memory (K=128, R=8, os=8):\n\n```\n     N      Before      After      Saved\n-------------------------------------------\n   512      0.98 GB     0.14 GB    7×\n  1024      1.97 GB     0.28 GB    7×\n  2048      3.93 GB     0.56 GB    7×\n  4096      7.87 GB     1.12 GB    7×\n```\n\n## Benchmarks\n\n```bash\npython benchmarks/bench_dense_vs_sparse.py\n```\n\n**CPU performance (single core):**\n\n```\n     N |   Dense (ms) |  Sparse (ms) |  Speedup | Mem ratio\n   256 |         3.94 |        62.39 |    0.06x |       2.0x\n   512 |         7.08 |       126.09 |    0.06x |       4.0x\n  1024 |        17.44 |       259.18 |    0.07x |       8.0x\n  2048 |        56.02 |       544.34 |    0.10x |      16.0x\n  4096 |       202.81 |      1131.41 |    0.18x |      32.0x\n  8192 |       782.37 |      2467.90 |    0.32x |      64.0x\n```\n\nDense is currently faster on CPU at every N tested, though the gap narrows sharply as N grows. The memory savings (up to 64×) become the dominant advantage for long sequences. On GPU with custom kernels, sparse attention's O(N·K) complexity provides both speed and memory wins.\n\n## What's NOT Supported\n\n- **External attention masks** — use `config.causal=True` for autoregressive masking; for padding, zero out positions before calling forward\n- **KV caching for incremental decoding** — the full graph is rebuilt on every forward call\n- **Exact recall guarantees** — LSH is approximate, not exact top-k\n\n## Testing\n\n```bash\npip install pytest\npytest\n```\n\n| Test File | Covers |\n|-----------|--------|\n| `test_attention_shapes.py` | Self/cross-attention shape correctness |\n| `test_neighbor_graph.py` | Deduplication, self-edge placement |\n| `test_causal.py` | No future-token leakage |\n| `test_gradients.py` | Gradient flow through sparse operations |\n| `test_quality_and_scaling.py` | Memory scaling, recall@K benchmarks |\n| `test_gqa_mqa.py` | Grouped/multi-query attention |\n| `test_global_tokens_and_mask_guard.py` | Explicit global indices, mask rejection |\n| `test_multiprobe_lsh.py` | Multi-probe LSH correctness |\n| `test_build_apply_graph_split.py` | build_graph/apply_graph consistency |\n| `test_cached_graph_transformer.py` | Cross-layer graph caching |\n\n## License\n\nMIT — see [LICENSE](LICENSE) for details.\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fdaedalus%2Fssa","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fdaedalus%2Fssa","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fdaedalus%2Fssa/lists"}