{"id":34102856,"url":"https://github.com/larslorch/stadion","last_synced_at":"2026-03-10T07:31:50.210Z","repository":{"id":222863510,"uuid":"758578846","full_name":"larslorch/stadion","owner":"larslorch","description":"Causal Modeling with Stationary Diffusions, AISTATS 2024","archived":false,"fork":false,"pushed_at":"2025-03-03T08:46:02.000Z","size":384,"stargazers_count":20,"open_issues_count":0,"forks_count":3,"subscribers_count":2,"default_branch":"main","last_synced_at":"2025-12-16T23:05:21.312Z","etag":null,"topics":["causal-inference","causality","dynamical-systems","kernel-methods","sde"],"latest_commit_sha":null,"homepage":"https://arxiv.org/abs/2310.17405","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"mit","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/larslorch.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":null,"code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null,"zenodo":null,"notice":null,"maintainers":null,"copyright":null,"agents":null,"dco":null,"cla":null}},"created_at":"2024-02-16T16:12:20.000Z","updated_at":"2025-11-22T19:18:26.000Z","dependencies_parsed_at":"2025-12-14T17:03:50.354Z","dependency_job_id":null,"html_url":"https://github.com/larslorch/stadion","commit_stats":null,"previous_names":["larslorch/stadion"],"tags_count":0,"template":false,"template_full_name":null,"purl":"pkg:github/larslorch/stadion","repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/larslorch%2Fstadion","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/larslorch%2Fstadion/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/larslorch%2Fstadion/releases","manifests_ur
l":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/larslorch%2Fstadion/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/larslorch","download_url":"https://codeload.github.com/larslorch/stadion/tar.gz/refs/heads/main","sbom_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/larslorch%2Fstadion/sbom","scorecard":null,"host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":286080680,"owners_count":30326908,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2026-03-10T05:25:20.737Z","status":"ssl_error","status_checked_at":"2026-03-10T05:25:17.430Z","response_time":106,"last_error":"SSL_read: unexpected eof while reading","robots_txt_status":"success","robots_txt_updated_at":"2025-07-24T06:49:26.215Z","robots_txt_url":"https://github.com/robots.txt","online":false,"can_crawl_api":true,"host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["causal-inference","causality","dynamical-systems","kernel-methods","sde"],"created_at":"2025-12-14T17:03:24.806Z","updated_at":"2026-03-10T07:31:50.167Z","avatar_url":"https://github.com/larslorch.png","language":"Python","funding_links":[],"categories":[],"sub_categories":[],"readme":"# Causal Modeling with Stationary Diffusions\n\n[![PyPi](https://img.shields.io/pypi/v/stadion?logo=PyPI)](https://pypi.org/project/stadion/)\n\nThis is the Python package for \n*\"Causal Modeling with Stationary Diffusions\"*\n([Lorch et al., 2024](https://arxiv.org/abs/2310.17405)).\nTo install the latest version, run:\n```bash\npip install stadion\n```\nThe `stadion` package allows learning **stationary** systems of\n**stochastic differential equations 
(SDEs)**, whose stationary densities\nmatch the empirical distribution of a target dataset.\nThe target dataset for learning the SDEs\ncontains i.i.d. samples from the stationary density, \nnot a time series.\nPut differently, we perform system identification \nfrom the stationary distribution.\nWhen provided with [several datasets](#multiple-interventional-datasets) \n(e.g., from different experimental conditions),\nthe algorithm learns one SDE model that fits all observed distributions\nusing jointly learned intervention parameters that perturb \nthe SDE model. \n\n\nThe objective for learning the SDE parameters is the \n**kernel deviation from stationarity (KDS)**.\nThe KDS depends on the SDEs and a kernel function, and \nits sample approximation is computed using only the target dataset. \nHence, optimizing the KDS does not require rolling out trajectories \nfrom the SDE model or backpropagating gradients through time.\nThe SDE drift and diffusion functions can \nbe **arbitrary nonlinear, differentiable functions**.\nThis package also provides the KDS as a stand-alone [loss function](#kds-loss-function).\n\nOur implementation leverages efficient vectorization, auto-diff, \nJIT compilation, and (multi-device) hardware acceleration \nwith [JAX](https://github.com/google/jax). \n\n\n## Quick Start\n\nThe following code demonstrates how to use the `stadion` package. 
\nIn this example, we use the KDS to learn a linear stationary SDE model from \na dataset sampled from a Gaussian distribution.\n```python\nfrom jax import random\nfrom stadion.models import LinearSDE\n\nkey = random.PRNGKey(0)\nn, d = 1000, 5\n\n# generate a dataset\nkey, subk = random.split(key)\nw = random.normal(subk, shape=(d, d))\n\nkey, subk = random.split(key)\ndata = random.normal(subk, shape=(n, d)) @ w\n\n# fit stationary diffusion model\nmodel = LinearSDE()\nkey, subk = random.split(key)\nmodel.fit(subk, data)\n\n# sample from model and get parameters\nkey, subk = random.split(key)\nx_pred = model.sample(subk, 100)\nparams = model.param\n```\nCurrently, the following SDE model classes are implemented in `stadion.models`:\n\n- [`LinearSDE`](stadion/models/linear.py)\n- [`MLPSDE`](stadion/models/mlp.py)\n\nThe `MLPSDE` model is a generalization of the `LinearSDE` model to\nnonlinear drift functions.\nTo support the inference functionality in the code snippet above, \nnew model classes have to inherit from\n[`SDE`](stadion/sde.py) and [`KDSMixin`](stadion/inference.py)\nand implement the methods decorated with `@abstractmethod`,\nas `LinearSDE` and `MLPSDE` do.\n\n## Additional Examples\n\n### KDS loss function\n\nThe `stadion` package provides the KDS as an \noff-the-shelf loss function.\nBelow, we define custom SDE functions `f` and `sigma`\nand a kernel `k`, use [`kds_loss`](stadion/kds.py) to create the\ncorresponding loss function, and compute its value and gradient with respect to the parameters of `f` and `sigma`.\nThis may be useful when using the KDS loss in\ncustom implementations that do not inherit from \n[`SDE`](stadion/sde.py) and [`KDSMixin`](stadion/inference.py).\nHere, `f` and `sigma` can be arbitrary differentiable, possibly\nnonlinear, functions.\n\n\n```python\n...  # continued from the Quick Start example (defines random, key, d, and data)\n\nfrom jax import numpy as jnp, value_and_grad\nfrom stadion import kds_loss\n\n# SDE drift (f) and diffusion (sigma) functions\ndef f(x, param):\n    return param[\"w\"] @ x + param[\"b\"]\n\ndef sigma(x, param):\n 
   return jnp.exp(param[\"c\"]) * jnp.eye(d)\n\n# kernel\ndef k(x, y):\n    return jnp.exp(- jnp.square(x - y).sum(-1) / 100)\n\n# create KDS loss function\nloss_fun = kds_loss(f, sigma, k)\n\n# compute the loss and its gradient w.r.t. params for the dataset and a random parameter setting\nkey, *subk = random.split(key, 4)\nparams = {\n    \"w\": random.normal(subk[0], shape=(d, d)),\n    \"b\": random.normal(subk[1], shape=(d,)),\n    \"c\": random.normal(subk[2], shape=(d,)),\n}\n\nloss, dparams = value_and_grad(loss_fun, argnums=1)(data, params)\n```\n\n### Multiple Interventional Datasets\n\nGiven multiple datasets, \nthe algorithm jointly learns one causal SDE model with \nseparate intervention parameters for each dataset.\nThe intervention parameters are used to\nfit all observed distributions through interventions \nin the shared SDE model.\nBelow, we add two interventional datasets and assume we know that they \nintervene on variables 2 and 4 (array indices 1 and 3), respectively, which \nrestricts the learnable intervention parameters to these variables.\n\n\n```python\n...  # continued from the previous examples (defines random, jnp, LinearSDE, key, n, d, w, and data)\n\n# sample two more datasets with shift interventions\na, targets_a =  3, jnp.array([0, 1, 0, 0, 0])\nb, targets_b = -5, jnp.array([0, 0, 0, 1, 0])\n\nkey, subk_0, subk_1 = random.split(key, 3)\ndata_a = (random.normal(subk_0, shape=(n, d)) + a * targets_a) @ w\ndata_b = (random.normal(subk_1, shape=(n, d)) + b * targets_b) @ w\n\n# fit stationary diffusion model\nmodel = LinearSDE()\nkey, subk = random.split(key)\nmodel.fit(\n    subk,\n    [data, data_a, data_b],\n    targets=[jnp.zeros(d), targets_a, targets_b],\n)\n\n# get inferred model and intervention parameters\nparam = model.param\nintv_param = model.intv_param\n\n# sample from the model under the intervention parameters learned for the\n# first interventional environment (data_a, at index 1 of the dataset list)\nintv_param_a = intv_param.index_at(1)\nkey, subk = random.split(key)\nx_pred_a = model.sample(subk, 100, intv_param=intv_param_a)\n```\n\n\n## Custom Installation and Branches\n\nThe latest release is published on PyPI, \nso the best way to install `stadion` is 
using `pip`\nas explained above.\nFor custom installations, we recommend using `conda` and creating a new environment \nvia `conda env create --file environment.yaml`.\n\nThe repository has two branches:\n- `main` (recommended): Lightweight, easy-to-use package for applying `stadion` in your research or applications.\n- `aistats`: Code to reproduce the results in [Lorch et al. (2024)](https://arxiv.org/abs/2310.17405). \nThe purpose of this branch is reproducibility; the branch is no longer updated and may contain outdated notation and documentation.\n\n## Reference\n\n```\n@inproceedings{lorch2024causal,\n  title={Causal Modeling with Stationary Diffusions},\n  author={Lorch, Lars and Krause, Andreas and Sch{\\\"o}lkopf, Bernhard},\n  booktitle={International Conference on Artificial Intelligence and Statistics},\n  pages={1927--1935},\n  year={2024},\n  organization={PMLR}\n}\n```\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Flarslorch%2Fstadion","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Flarslorch%2Fstadion","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Flarslorch%2Fstadion/lists"}