{"id":13807008,"url":"https://github.com/rlouf/mcx","last_synced_at":"2025-04-05T00:10:42.685Z","repository":{"id":48319712,"uuid":"235531205","full_name":"rlouf/mcx","owner":"rlouf","description":"Express \u0026 compile probabilistic programs for performant inference on CPU \u0026 GPU. Powered by JAX.","archived":false,"fork":false,"pushed_at":"2024-03-20T15:48:42.000Z","size":903,"stargazers_count":324,"open_issues_count":19,"forks_count":17,"subscribers_count":17,"default_branch":"master","last_synced_at":"2024-10-12T16:44:16.526Z","etag":null,"topics":["probabilistic-programming"],"latest_commit_sha":null,"homepage":"https://rlouf.github.io/mcx","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"apache-2.0","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/rlouf.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":"CONTRIBUTING.md","funding":null,"license":"LICENSE","code_of_conduct":"CODE_OF_CONDUCT.md","threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null}},"created_at":"2020-01-22T08:38:43.000Z","updated_at":"2024-10-04T07:08:40.000Z","dependencies_parsed_at":"2024-08-04T01:06:57.576Z","dependency_job_id":"41f4ebb5-e4a9-43ad-a1df-715c90a34faa","html_url":"https://github.com/rlouf/mcx","commit_stats":{"total_commits":370,"total_committers":12,"mean_commits":"30.833333333333332","dds":"0.12702702702702706","last_synced_commit":"26c316f2911dac86fbc585b66a8652872187f64e"},"previous_names":[],"tags_count":0,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/rlouf%2Fmcx","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/rlouf%2Fmcx/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/rlouf%2Fmcx/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/rlouf%2Fmcx/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/rlouf","download_url":"https://codeload.github.com/rlouf/mcx/tar.gz/refs/heads/master","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":247266565,"owners_count":20910836,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["probabilistic-programming"],"created_at":"2024-08-04T01:01:19.299Z","updated_at":"2025-04-05T00:10:42.669Z","avatar_url":"https://github.com/rlouf.png","language":"Python","readme":"\u003ch2 align=\"center\"\u003e\n  MCX\n\u003c/h2\u003e\n\n\u003ch3 align=\"center\"\u003e\n XLA-rated Bayesian inference\n\u003c/h3\u003e\n\nMCX is a probabilistic programming library with a laser focus on sampling\nmethods. MCX transforms model definitions into logpdf or sampling\nfunctions. These functions are JIT-compiled with JAX; they support batching and\ncan be executed transparently on CPU, GPU or TPU.\n\nThe project is currently in its infancy. It is a moonshot towards providing\nsequential inference as a first-class citizen, as well as performant sampling\nmethods for Bayesian deep learning.\n\nMCX's philosophy:\n\n1. Knowing how to express a graphical model and manipulate NumPy arrays should\n   be enough to define a model.\n2. Models should be modular and reusable.\n3. Inference should be performant and should leverage GPUs.\n\nSee the [documentation](https://rlouf.github.io/mcx) for more information. 
See [this issue](https://github.com/rlouf/mcx/issues/1) for an updated roadmap for v0.1.\n\n## Current API\n\nNote that there are still many moving pieces in `mcx` and the API may change\nslightly.\n\n```python\nimport arviz as az\nimport jax\nimport jax.numpy as jnp\nimport numpy as np\n\nimport mcx\nfrom mcx.distributions import Exponential, Normal\nfrom mcx.inference import HMC\n\nrng_key = jax.random.PRNGKey(0)\n\n# Simulated data: y = 3x + Gaussian noise\nx_data = np.random.normal(0, 5, size=(1000, 1))\ny_data = 3 * x_data + np.random.normal(size=x_data.shape)\n\n@mcx.model\ndef linear_regression(x, lmbda=1.):\n    scale \u003c~ Exponential(lmbda)\n    coefs \u003c~ Normal(jnp.zeros(jnp.shape(x)[-1]), 1)\n    preds \u003c~ Normal(jnp.dot(x, coefs), scale)\n    return preds\n\nprior_predictive = mcx.prior_predict(rng_key, linear_regression, (x_data,))\n\nposterior = mcx.sampler(\n    rng_key,\n    linear_regression,\n    (x_data,),\n    {'preds': y_data},\n    HMC(100),\n).run()\n\naz.plot_trace(posterior)\n\nposterior_predictive = mcx.posterior_predict(rng_key, linear_regression, (x_data,), posterior)\n```\n\n## MCX's future\n\nWe are currently considering the following future directions:\n\n- **Neural network layers:** You can follow discussions about the API in [this Pull Request](https://github.com/rlouf/mcx/pull/16).\n- **Programs with stochastic support:** Discussion in [this issue](https://github.com/rlouf/mcx/issues/37).\n- **Tools for causal inference:** Made easier by the internal representation of models as\n  graphs.\n\nYou are more than welcome to contribute to these discussions or to suggest\nother future directions.\n\n## Linear sampling\n\nLike most PPLs, MCX implements a batch sampling runtime:\n\n```python\nsampler = mcx.sampler(\n    rng_key,\n    linear_regression,\n    args,          # tuple of positional arguments to the model, e.g. (x_data,)\n    observations,  # dict of observed values, e.g. {'preds': y_data}\n    kernel,        # inference kernel, e.g. HMC(100)\n)\n\nposterior = sampler.run()\n```\n\nThe warmup trace is discarded by default, but you can obtain it by running:\n\n```python\nwarmup_posterior = sampler.warmup()\nposterior = sampler.run()\n```\n\nYou can draw more samples from the chain after a run and combine the\ntwo traces:\n\n```python\nposterior += sampler.run()\n```\n\nBy default, MCX samples in interactive mode: it uses a Python `for` loop and\ndisplays a progress bar and various diagnostics. For faster sampling, use:\n\n```python\nposterior = sampler.run(compile=True)\n```\n\nIn a notebook, you can combine the two: a short interactive run gives a lower\nbound on the sampling rate, which helps you decide how many samples to draw.\n\n### Interactive sampling\n\nSampling the posterior is an iterative process, yet most libraries only provide\nbatch sampling. `mcx` already implements a generator runtime, which opens up\nmany possibilities, such as:\n\n- Dynamic interruption of inference (say, after reaching a set number of\n  effective samples);\n- Real-time monitoring of inference with a tool like TensorBoard;\n- Easier debugging.\n\n```python\nsampler = mcx.sampler(\n    rng_key,\n    linear_regression,\n    args,\n    observations,\n    kernel,\n)\n\n# Iterate over the sampler to collect samples one at a time\ntrace = mcx.Trace()\nfor sample in sampler:\n    trace.append(sample)\n\n# Or advance the chain one step at a time\nsamples = iter(sampler)\nsample = next(samples)\n```\n\nNote that the performance of the interactive mode is significantly lower than\nthat of the batch sampler. 
However, the two modes can be combined:\n\n```python\ntrace = mcx.Trace()\nfor i, sample in enumerate(sampler):\n    print(do_something(sample))  # `do_something` is a placeholder for any user-defined check\n    trace.append(sample)\n    # Periodically switch to the faster compiled runtime\n    if i % 10 == 0:\n        trace += sampler.run(100_000, compile=True)\n```\n\n## Important note\n\nMCX takes a lot of inspiration from other probabilistic programming languages\nand libraries: Stan (for NUTS and its very knowledgeable community), PyMC3 (for its\nsimple API), TensorFlow Probability (for its shape system and inference\nvectorization), (Num)Pyro (for the use of JAX in the backend), Gen.jl and\nTuring.jl (for composable inference), Soss.jl (for its generative model API),\nAnglican, and many others I have forgotten.\n","funding_links":[],"categories":["Python","Libraries"],"sub_categories":["New Libraries","Inactive Libraries"],"project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Frlouf%2Fmcx","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Frlouf%2Fmcx","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Frlouf%2Fmcx/lists"}