# Canonical Interp: Efficient Developmental Interpretability

[![PyPI version](https://img.shields.io/pypi/v/canonical-interp)](https://pypi.org/project/canonical-interp/)

> **Note:** This package is under active development (0.X.Y). Breaking changes should be expected between minor versions.

A lean, efficient Local Learning Coefficient (LLC / RLCT) estimator. Rewrite of [Timaeus's `devinterp`](https://github.com/timaeus-research/devinterp).
MIT licensed.

## What is the LLC?

The Local Learning Coefficient is a measure of effective model complexity at a specific point in weight space, grounded in Singular Learning Theory (SLT). For a model trained to a loss minimum $w^{\ast}$, the LLC is the local real log canonical threshold (RLCT) at $w^{\ast}$, and it is estimated as:

$$\hat\lambda = n\beta \cdot (\bar L_{\text{SGLD}} - L_0)$$

where $L_0$ is the loss at $w^{\ast}$, $\bar L_{\text{SGLD}}$ is the time-averaged loss of SGLD chains started at $w^{\ast}$, and $n\beta$ is an inverse-temperature factor. A higher LLC means the model is effectively using more of its parameters at that point; a lower LLC signals degeneracy or symmetry.

## Installation

```bash
pip install canonical-interp
# or with uv:
uv add canonical-interp
```

Requires Python ≥ 3.12 and PyTorch ≥ 2.8.

## Quick start

```python
import math

import torch
from torch.utils.data import DataLoader
from canonical_interp.slt import LLCEstimator

# `model` and `train_dataset` are your trained network and its training data.
n = len(train_dataset)
nbeta = n / math.log(n)  # standard SLT choice: beta = 1 / log(n)

estimator = LLCEstimator(
    draws=200,
    chains=10,
    burnin_steps=100,
    steps_bw_draws=1,
    learning_rate=1e-5,
    localization=100.0,
    nbeta=nbeta,
)

loader = DataLoader(train_dataset, batch_size=512, shuffle=False,
                    pin_memory=True, num_workers=4, persistent_workers=True)

llc = estimator.estimate_llc(model, loader)
print(f"LLC: {llc.mean():.4f}  (per chain: {llc.tolist()})")
```

`criterion_fn` defaults to `F.cross_entropy`, so for classification tasks you don't need to pass anything extra.

## Examples

### Custom loss (regression, custom architectures)

Supply any `(logits, targets) -> scalar` function. The library handles the `functional_call` wrapping internally, so you don't need to write it yourself.

```python
import torch.nn.functional as F

llc = estimator.estimate_llc(model, loader, criterion_fn=F.mse_loss)
```

For more control (e.g.
label smoothing, auxiliary losses), pass a lambda or a named function:

```python
def my_loss(logits, targets):
    return F.cross_entropy(logits, targets, label_smoothing=0.1)

llc = estimator.estimate_llc(model, loader, criterion_fn=my_loss)
```

### Custom dataloader format

If your DataLoader yields dicts or tuples with more than two elements, pass an `unpack_fn` to extract `(x, y)`. Device movement is handled internally.

```python
# DataLoader yields {"input_ids": ..., "labels": ...}
llc = estimator.estimate_llc(
    model, loader,
    unpack_fn=lambda batch: (batch["input_ids"], batch["labels"]),
)
```

### Mixed precision (bfloat16 / float16)

Pass `dtype` to the constructor. Autocast is applied to the forward pass before `vmap` and `torch.compile` see it, so the compiler can fuse dtype casts into the rest of the graph.

```python
estimator = LLCEstimator(
    draws=200, chains=10, burnin_steps=100, steps_bw_draws=1,
    learning_rate=1e-5, localization=100.0, nbeta=nbeta,
    dtype=torch.bfloat16,  # or torch.float16
)

llc = estimator.estimate_llc(model, loader, devices="cuda")
```

### Multi-GPU

Pass a list of devices and set `chain_batch` to control how many chains run on each device at once.
Chain batches are distributed round-robin across the devices and run in parallel via a `ThreadPoolExecutor`.

```python
llc = estimator.estimate_llc(
    model, loader,
    devices=["cuda:0", "cuda:1"],
    chain_batch=5,   # 5 chains per device call; 10 chains total = 2 calls
)
```

### Reproducibility

```python
llc = estimator.estimate_llc(model, loader, seed=42)
```

A master seed is used to derive independent per-device seeds deterministically, so results are reproducible regardless of how chains are batched across devices.

### Accessing per-chain metrics

After calling `estimate_llc`, use `get_metrics()` to retrieve the full loss trace and per-chain LLC values:

```python
llc = estimator.estimate_llc(model, loader)
metrics = estimator.get_metrics()

metrics["llc_mean"]     # scalar: average LLC across chains
metrics["llc_std"]      # scalar: std of per-chain LLCs
metrics["llcs"]         # [chains]: per-chain LLC estimates
metrics["losses"]       # [chains, draws]: full loss trace per chain
metrics["losses_mean"]  # scalar: mean of final-draw losses
metrics["losses_std"]   # scalar: std of final-draw losses
```

### Hyperparameter grid search

When tuning the SGLD step size `epsilon`, the localization strength `gamma`, and `nbeta`, use `LLCGridSearch` to sweep over ranges and compare the results in a single DataFrame.

```python
from canonical_interp import LLCGridSearch, GridSearchConfig

cfg = GridSearchConfig(
    epsilon=(1e-6, 1e-4),   # range to sweep
    gamma=(10.0, 500.0),    # range to sweep
    nbeta=nbeta,            # fixed value
    estimates_per_dim=5,    # 5 values per range, so 5x5x1 = 25 runs
    draws=200,
    chains=8,
    burnin_steps=100,
)

gs = LLCGridSearch(cfg)
df = gs.run_grid_search(model, loader, devices="cuda")
print(df)
#   epsilon   gamma   nbeta  llc_mean  llc_std  loss_mean  loss_std
# 0  1e-06    10.0    57.6     2.31     0.12      0.041     0.003
# 1  1e-06   132.5    57.6     2.45     0.09      0.039     0.002
# ...
```

All options supported by `LLCEstimator.estimate_llc` (multi-GPU, compilation, custom loss, `unpack_fn`, seed) pass through to each grid point.

## Performance by default

LLC estimation is compute-intensive: it requires thousands of forward+backward passes across many parallel chains. This library is built around the principle that sensible defaults should leave no performance on the table.

| Feature | What it does |
|---|---|
| **`vmap` over chains** | All chains in a batch run as one batched computation, with no Python loop and no per-chain overhead |
| **`torch.compile`** | The vmapped grad function is JIT-compiled by default (`compile=True`) |
| **Autocast inside the transform** | Autocast wraps the forward pass *before* `vmap`/`grad` see it, so the compiler can fuse dtype casts rather than treating them as opaque boundaries |
| **DataLoader warnings** | The estimator warns when `pin_memory=False`, `num_workers=0`, or `persistent_workers=False`, all of which cause avoidable stalls when the dataloader is cycled for the full SGLD run |

### Recommended DataLoader settings for GPU runs

```python
from torch.utils.data import DataLoader

loader = DataLoader(
    dataset,
    batch_size=512,
    shuffle=False,       # non-shuffled for deterministic LLC; shuffle for training
    pin_memory=True,     # faster CPU→GPU transfer
    num_workers=4,       # overlap data loading with GPU compute
    persistent_workers=True,  # avoid worker respawn between SGLD epochs
)
```

## Hyperparameter guide

| Parameter | Typical range | Effect |
|---|---|---|
| `nbeta` | `n / log(n)` | Inverse temperature; scales the LLC estimate. Use `nbeta_from_effective_size(n)` or compute directly. |
| `learning_rate` | `1e-6` – `1e-4` | SGLD step size. Too large → divergence; too small → chains barely move. |
| `localization` | `1` – `1000` | Elastic pull toward the initial weights. Higher values keep chains near $w^{\ast}$, giving tighter LLC estimates. |
| `burnin_steps` | 100 – 500 | Steps discarded before draws. Should cover transient behaviour. |
| `draws` | 100 – 500 | Samples per chain used to estimate $\bar L$. More draws → lower variance. |
| `chains` | 5 – 20 | Independent chains. More chains → lower variance; all run in parallel via vmap. |

## How it works

All chain parameters are stacked into a single batched tensor, and a single functional forward+backward pass is vmapped over them, which maps naturally onto GPU parallelism. The statistical procedure follows the SGLD-based LLC estimator from Singular Learning Theory. Results can be validated against known closed-form RLCTs (see `test_known_rlct.py`).

This library was built on top of ideas from [devinterp](https://github.com/timaeus-research/devinterp).
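The SGLD-based procedure and the closed-form RLCT check mentioned under *How it works* can be illustrated without PyTorch. The sketch below is not this library's implementation: it uses a plain Python loop instead of `vmap`, an exact toy loss instead of minibatches, and it assumes the devinterp-style update $\Delta w = -\tfrac{\epsilon}{2}\big(n\beta\,\nabla L(w) + \gamma\,(w - w^{\ast})\big) + \mathcal{N}(0, \epsilon)$, with `epsilon` and `gamma` standing in for `learning_rate` and `localization`. The helper name `sgld_llc` and the two toy losses are hypothetical. The regular loss $L(w) = w_1^2 + w_2^2$ has RLCT $d/2 = 1$; the singular loss $L(w) = w^4$ has RLCT $1/4$.

```python
import math
import random
import statistics
from typing import Callable, List, Sequence


def sgld_llc(
    loss: Callable[[List[float]], float],
    grad: Callable[[List[float]], List[float]],
    w_star: Sequence[float],
    nbeta: float,
    epsilon: float = 2e-4,  # SGLD step size (plays the role of `learning_rate`)
    gamma: float = 1.0,     # localization strength (plays the role of `localization`)
    chains: int = 4,
    burnin_steps: int = 500,
    draws: int = 10_000,
    seed: int = 0,
) -> List[float]:
    """Hypothetical helper: one LLC estimate per chain, nbeta * (mean draw loss - L0)."""
    rng = random.Random(seed)
    L0 = loss(list(w_star))
    noise_std = math.sqrt(epsilon)  # Langevin noise is N(0, epsilon)
    estimates = []
    for _ in range(chains):
        w = list(w_star)  # every chain starts at the minimum w*
        total = 0.0
        for step in range(burnin_steps + draws):
            g = grad(w)
            # Assumed update: -eps/2 * (nbeta * grad L + gamma * (w - w*)) + noise
            w = [
                wi - 0.5 * epsilon * (nbeta * gi + gamma * (wi - ws))
                + rng.gauss(0.0, noise_std)
                for wi, gi, ws in zip(w, g, w_star)
            ]
            if step >= burnin_steps:  # discard burn-in, then accumulate draw losses
                total += loss(w)
        estimates.append(nbeta * (total / draws - L0))
    return estimates


# Regular model: L(w) = w1^2 + w2^2 with d = 2, so the RLCT is d/2 = 1.
quad_llc = statistics.fmean(sgld_llc(
    lambda w: sum(x * x for x in w),
    lambda w: [2.0 * x for x in w],
    w_star=[0.0, 0.0], nbeta=100.0,
))

# Singular model: L(w) = w^4, whose RLCT is 1/4 (below the regular value 1/2).
quartic_llc = statistics.fmean(sgld_llc(
    lambda w: w[0] ** 4,
    lambda w: [4.0 * w[0] ** 3],
    w_star=[0.0], nbeta=100.0, chains=8, draws=20_000,
))

print(f"quadratic: {quad_llc:.3f}  quartic: {quartic_llc:.3f}")
```

With these settings the two averages should come out close to 1 and slightly below 1/4 respectively, up to Monte Carlo noise and a small bias from the finite step size and the localization term. The quartic estimate falling well below the regular one-parameter value of 1/2 is the kind of degeneracy signal described in *What is the LLC?*.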