https://github.com/sbaumohl/canonical-interp
- Host: GitHub
- URL: https://github.com/sbaumohl/canonical-interp
- Owner: sbaumohl
- License: mit
- Created: 2026-04-04T03:40:53.000Z (3 months ago)
- Default Branch: main
- Last Pushed: 2026-04-06T05:02:53.000Z (3 months ago)
- Last Synced: 2026-04-06T07:18:13.942Z (3 months ago)
- Language: Python
- Size: 272 KB
- Stars: 1
- Watchers: 0
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
README
# Canonical Interp: Efficient Developmental Interpretability
[PyPI](https://pypi.org/project/canonical-interp/)
> **Note:** This package is under active development (0.X.Y). Breaking changes should be expected between minor versions.
A lean, efficient Local Learning Coefficient (LLC / RLCT) estimator. Rewrite of [Timaeus's `devinterp`](https://github.com/timaeus-research/devinterp). MIT licensed.
## What is the LLC?
The Local Learning Coefficient (LLC) is a measure of effective model complexity at a specific point in weight space, grounded in Singular Learning Theory. For a model trained to a loss minimum $w^{\ast}$, the estimator returns $\hat\lambda$, an estimate of the real log canonical threshold (RLCT) at $w^{\ast}$:
$$\hat\lambda = n\beta \cdot (\bar L_{\text{SGLD}} - L_0)$$
where $L_0$ is the loss at $w^{\ast}$, $\bar L_{\text{SGLD}}$ is the time-averaged loss of SGLD chains run from $w^{\ast}$, and $n\beta$ is an inverse-temperature factor. Higher LLC means the model is using more of its parameter space at that point. Lower LLC signals degeneracy or symmetry.
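
As a worked illustration of the formula above, the sketch below computes $\hat\lambda$ from a made-up loss trace in plain Python. The helper name `llc_from_trace` and all numbers are illustrative assumptions, not part of the package API.

```python
import math
from statistics import fmean


def llc_from_trace(sgld_losses, init_loss, nbeta):
    """Return nbeta * (mean SGLD loss - loss at w*), per the formula above."""
    return nbeta * (fmean(sgld_losses) - init_loss)


# Hypothetical numbers: 10,000 training samples and a very short loss trace.
n = 10_000
nbeta = n / math.log(n)                     # the n / log(n) choice for the inverse temperature
init_loss = 0.040                           # L_0, the loss at the trained weights w*
sgld_losses = [0.043, 0.045, 0.044, 0.046]  # losses sampled along one SGLD chain

llc_hat = llc_from_trace(sgld_losses, init_loss, nbeta)
print(f"nbeta = {nbeta:.1f}, LLC estimate = {llc_hat:.2f}")
```

Because the estimate is a rescaled loss gap, any change to `nbeta` rescales $\hat\lambda$ by the same factor.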
## Installation
```bash
pip install canonical-interp
# or with uv:
uv add canonical-interp
```
Requires Python ≥ 3.12 and PyTorch ≥ 2.8.
## Quick start
```python
import math

import torch
from torch.utils.data import DataLoader
from canonical_interp.slt import LLCEstimator

# `model` is your trained network and `train_dataset` its training set.
n = len(train_dataset)
nbeta = n / math.log(n)  # standard SLT choice

estimator = LLCEstimator(
    draws=200,
    chains=10,
    burnin_steps=100,
    steps_bw_draws=1,
    learning_rate=1e-5,
    localization=100.0,
    nbeta=nbeta,
)
loader = DataLoader(train_dataset, batch_size=512, shuffle=False,
                    pin_memory=True, num_workers=4, persistent_workers=True)
llc = estimator.estimate_llc(model, loader)
print(f"LLC: {llc.mean():.4f} (per chain: {llc.tolist()})")
```
`criterion_fn` defaults to `F.cross_entropy`. For classification tasks you don't need to pass anything extra.
## Examples
### Custom loss (regression, custom architectures)
Supply any `(logits, targets) -> scalar` function. The library handles the `functional_call` wrapping internally, so you don't write it.
```python
import torch.nn.functional as F
llc = estimator.estimate_llc(model, loader, criterion_fn=F.mse_loss)
```
For more control (e.g. label smoothing, auxiliary losses), pass a lambda or a named function:
```python
import torch.nn.functional as F

def my_loss(logits, targets):
    return F.cross_entropy(logits, targets, label_smoothing=0.1)

llc = estimator.estimate_llc(model, loader, criterion_fn=my_loss)
```
### Custom dataloader format
If your DataLoader yields dicts or tuples with more than two elements, pass an `unpack_fn` to extract `(x, y)`. Device movement is handled internally.
```python
# DataLoader yields {"input_ids": ..., "labels": ...}
llc = estimator.estimate_llc(
    model, loader,
    unpack_fn=lambda batch: (batch["input_ids"], batch["labels"]),
)
```
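
The tuple case mentioned above can be handled the same way. The sketch below assumes a hypothetical loader that yields `(inputs, attention_mask, labels)` triples; the function name `unpack_triple` and the batch layout are illustrative only.

```python
def unpack_triple(batch):
    """Pick (x, y) out of a hypothetical (inputs, attention_mask, labels) batch."""
    inputs, _attention_mask, labels = batch
    return inputs, labels


# Stand-in batch made of plain lists; a real DataLoader would yield tensors.
x, y = unpack_triple(([1, 2, 3], [1, 1, 0], [0, 1, 0]))
```

A named function like this would be passed as `unpack_fn=unpack_triple` in place of the lambda.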
### Mixed precision (bfloat16 / float16)
Pass `dtype` to the constructor. Autocast is applied to the forward pass before `vmap` and `torch.compile` see it, so the compiler can fuse dtype casts into the rest of the graph.
```python
estimator = LLCEstimator(
    draws=200, chains=10, burnin_steps=100, steps_bw_draws=1,
    learning_rate=1e-5, localization=100.0, nbeta=nbeta,
    dtype=torch.bfloat16,  # or torch.float16
)
llc = estimator.estimate_llc(model, loader, devices="cuda")
```
### Multi-GPU
Pass a list of devices and set `chain_batch` to control how many chains run on each device at once. Chain batches are distributed round-robin and run in parallel via a `ThreadPoolExecutor`.
```python
llc = estimator.estimate_llc(
    model, loader,
    devices=["cuda:0", "cuda:1"],
    chain_batch=5,  # 5 chains per device call; 10 chains total = 2 calls
)
```
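
To make the scheduling concrete, here is a minimal standard-library sketch of round-robin assignment plus a thread pool. It is not the library's implementation: `plan_chain_batches` and `run_batch` are hypothetical names, and `run_batch` is a placeholder for the real sampling work.

```python
from concurrent.futures import ThreadPoolExecutor


def plan_chain_batches(chains, chain_batch, devices):
    """Split chain indices into batches and assign each batch a device round-robin."""
    batches = [list(range(start, min(start + chain_batch, chains)))
               for start in range(0, chains, chain_batch)]
    return [(devices[i % len(devices)], batch) for i, batch in enumerate(batches)]


def run_batch(device, chain_ids):
    # Placeholder: a real worker would run SGLD for these chains on `device`.
    return device, len(chain_ids)


plan = plan_chain_batches(chains=10, chain_batch=5, devices=["cuda:0", "cuda:1"])
with ThreadPoolExecutor(max_workers=2) as pool:
    # pool.map preserves the order of `plan` even though the jobs run concurrently.
    results = list(pool.map(lambda job: run_batch(*job), plan))
print(results)  # [('cuda:0', 5), ('cuda:1', 5)]
```

With 10 chains and `chain_batch=5`, the plan contains two jobs, one per device, matching the comment in the example above.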
### Reproducibility
```python
llc = estimator.estimate_llc(model, loader, seed=42)
```
A master seed is used to derive independent per-device seeds deterministically, so results are reproducible regardless of how chains are batched across devices.
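
One way to obtain batching-independent results is to derive seeds by hashing the master seed together with an index. The sketch below keys the derivation on the chain index rather than the device, which is an assumption made for illustration and may differ from what the library does internally; `derive_seed` is a hypothetical helper.

```python
import hashlib


def derive_seed(master_seed, chain_index):
    """Map (master seed, chain index) to a stable non-negative 63-bit seed."""
    digest = hashlib.sha256(f"{master_seed}:{chain_index}".encode()).digest()
    return int.from_bytes(digest[:8], "big") >> 1


seeds = [derive_seed(42, i) for i in range(10)]

# Splitting the same chains into two batches leaves every chain's seed unchanged.
batched = [[derive_seed(42, i) for i in batch]
           for batch in ([0, 1, 2, 3, 4], [5, 6, 7, 8, 9])]
```

Because each seed depends only on the master seed and the index, regrouping chains across devices cannot change the random stream any chain sees.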
### Accessing per-chain metrics
After calling `estimate_llc`, use `get_metrics()` to retrieve the full loss trace and per-chain LLC values:
```python
llc = estimator.estimate_llc(model, loader)
metrics = estimator.get_metrics()
metrics["llc_mean"] # scalar: average LLC across chains
metrics["llc_std"] # scalar: std of per-chain LLCs
metrics["llcs"] # [chains]: per-chain LLC estimates
metrics["losses"] # [chains, draws]: full loss trace per chain
metrics["losses_mean"] # scalar: mean of final-draw losses
metrics["losses_std"] # scalar: std of final-draw losses
```
### Hyperparameter grid search
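When tuning `epsilon`, `gamma`, and `nbeta`, use `LLCGridSearch` to sweep over ranges and compare results in a single DataFrame. Judging by their ranges, `epsilon` and `gamma` appear to correspond to the estimator's `learning_rate` and `localization`. Each of the three can be given as a `(low, high)` range to sweep or as a single fixed value.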
```python
from canonical_interp import LLCGridSearch, GridSearchConfig
cfg = GridSearchConfig(
    epsilon=(1e-6, 1e-4),  # range to sweep
    gamma=(10.0, 500.0),   # range to sweep
    nbeta=nbeta,           # fixed value
    estimates_per_dim=5,   # 5 values per range, so 5x5x1 = 25 runs
    draws=200,
    chains=8,
    burnin_steps=100,
)
gs = LLCGridSearch(cfg)
df = gs.run_grid_search(model, loader, devices="cuda")
print(df)
#    epsilon  gamma  nbeta  llc_mean  llc_std  loss_mean  loss_std
# 0    1e-06   10.0   57.6      2.31     0.12      0.041     0.003
# 1    1e-06  132.5   57.6      2.45     0.09      0.039     0.002
# ...
```
All options supported by `LLCEstimator.estimate_llc` (multi-GPU, compilation, custom loss, `unpack_fn`, seed) pass through to each grid point.
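
The grid itself can be pictured as a Cartesian product of per-parameter axes. The sketch below is an assumption-laden reconstruction, not the library's code: linear spacing is inferred from the `gamma` values in the sample output (10.0, then 132.5), and whether `epsilon` is spaced linearly or logarithmically is not stated. The helpers `linspace` and `as_axis` are hypothetical.

```python
from itertools import product


def linspace(low, high, num):
    """Return `num` evenly spaced values from low to high inclusive."""
    if num == 1:
        return [low]
    step = (high - low) / (num - 1)
    return [low + i * step for i in range(num)]


def as_axis(value, num):
    """A (low, high) tuple becomes a swept range; a scalar stays a single fixed value."""
    return linspace(*value, num) if isinstance(value, tuple) else [value]


# Axes in the order (epsilon, gamma, nbeta), using the values from the example above.
axes = [as_axis(v, 5) for v in ((1e-6, 1e-4), (10.0, 500.0), 57.6)]
grid = list(product(*axes))
print(len(grid), grid[1])  # 25 (1e-06, 132.5, 57.6)
```

The second grid point reproduces row 1 of the sample DataFrame, which is what suggests the linear spacing.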
## Performance by default
LLC estimation is compute-intensive: it requires thousands of forward+backward passes across many parallel chains. This library is built around the principle that sensible defaults should leave no performance on the table.
| Feature | What it does |
|---|---|
| **`vmap` over chains** | All chains in a batch run as one batched computation, with no Python loop over chains and no per-chain overhead |
| **`torch.compile`** | The vmapped grad function is JIT-compiled by default (`compile=True`) |
| **Autocast inside the transform** | Autocast wraps the forward pass *before* `vmap`/`grad` see it, so the compiler can fuse dtype casts rather than treating them as opaque boundaries |
| **DataLoader warnings** | The estimator warns when `pin_memory=False`, `num_workers=0`, or `persistent_workers=False`, all of which cause avoidable stalls when the dataloader is cycled for the full SGLD run |
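
The DataLoader check in the last row could look roughly like the sketch below. It is a guess at the behaviour, not the library's code: `check_loader` is a hypothetical name, and a `SimpleNamespace` stands in for a real `torch.utils.data.DataLoader`, which exposes attributes with these same names.

```python
import warnings
from types import SimpleNamespace


def check_loader(loader):
    """Warn about DataLoader settings that tend to stall long sampling runs."""
    problems = []
    if not loader.pin_memory:
        problems.append("pin_memory=False slows host-to-GPU copies")
    if loader.num_workers == 0:
        problems.append("num_workers=0 loads data in the main process")
    elif not loader.persistent_workers:
        # Only meaningful when worker processes exist at all.
        problems.append("persistent_workers=False respawns workers every epoch")
    for message in problems:
        warnings.warn(message, stacklevel=2)
    return problems


# Stand-in object with the attribute names a torch DataLoader exposes.
slow_loader = SimpleNamespace(pin_memory=False, num_workers=0, persistent_workers=False)
issues = check_loader(slow_loader)
```

Returning the messages as well as warning makes the check easy to assert on in tests.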
### Recommended DataLoader settings for GPU runs
```python
loader = DataLoader(
    dataset,
    batch_size=512,
    shuffle=False,            # non-shuffled for deterministic LLC; shuffle for training
    pin_memory=True,          # faster CPU→GPU transfer
    num_workers=4,            # overlap data loading with GPU compute
    persistent_workers=True,  # avoid worker respawn between SGLD epochs
)
```
## Hyperparameter guide
| Parameter | Typical range | Effect |
|---|---|---|
| `nbeta` | `n / log(n)` | Inverse temperature; scales the LLC estimate. Use `nbeta_from_effective_size(n)` or compute directly. |
| `learning_rate` | `1e-6` – `1e-4` | SGLD step size. Too large → divergence; too small → chain doesn't move. |
| `localization` | `1` – `1000` | Elastic pull toward initial weights. Higher values keep chains near $w^{\ast}$ and reduce variance, but very large values restrict sampling and bias the estimate downward. |
| `burnin_steps` | 100 – 500 | Steps discarded before draws. Should cover transient behaviour. |
| `draws` | 100 – 500 | Samples per chain used to estimate $\bar L$. More draws → lower variance. |
| `chains` | 5 – 20 | Independent chains. More chains → lower variance; all run in parallel via vmap. |
## How it works
All chain parameters are stacked into a single batched tensor and a single functional forward+backward is vmapped over them, mapping naturally onto GPU parallelism. The statistical procedure follows the SGLD-based LLC estimator from Singular Learning Theory. Results can be validated against known closed-form RLCTs (see `test_known_rlct.py`).
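
The statistical procedure can be illustrated without PyTorch on a one-parameter toy problem. The sketch below assumes the usual localized SGLD update, $\Delta w = -\tfrac{\epsilon}{2}\left(n\beta\,\nabla L(w) + \gamma\,(w - w^{\ast})\right) + \sqrt{\epsilon}\,\eta$ with $\eta \sim \mathcal{N}(0, 1)$. Whether this library uses exactly this form is an assumption, and `sgld_llc_1d` is a hypothetical name. Chains run in a Python loop here purely for readability.

```python
import math
import random


def sgld_llc_1d(nbeta=100.0, epsilon=1e-3, gamma=1.0, chains=4,
                burnin_steps=200, draws=2000, seed=0):
    """Estimate the LLC of L(w) = w^2 at w* = 0 with localized SGLD."""
    rng = random.Random(seed)
    w_star, init_loss = 0.0, 0.0
    chain_means = []
    for _ in range(chains):
        w, total = w_star, 0.0
        for step in range(burnin_steps + draws):
            grad = 2.0 * w  # dL/dw for L(w) = w^2
            drift = nbeta * grad + gamma * (w - w_star)
            w += -0.5 * epsilon * drift + math.sqrt(epsilon) * rng.gauss(0.0, 1.0)
            if step >= burnin_steps:
                total += w * w  # loss at the newly sampled point
        chain_means.append(total / draws)
    mean_loss = sum(chain_means) / chains
    return nbeta * (mean_loss - init_loss)


print(f"LLC estimate: {sgld_llc_1d():.3f}")
```

For this regular one-parameter model the closed-form RLCT is 1/2, so the estimate should land near that value, up to sampling noise and small biases from the localization term and the finite step size.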
This library was built on top of ideas from [devinterp](https://github.com/timaeus-research/devinterp).