Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/rlouf/mcx
Express & compile probabilistic programs for performant inference on CPU & GPU. Powered by JAX.
probabilistic-programming
Last synced: 3 days ago
- Host: GitHub
- URL: https://github.com/rlouf/mcx
- Owner: rlouf
- License: apache-2.0
- Created: 2020-01-22T08:38:43.000Z (about 5 years ago)
- Default Branch: master
- Last Pushed: 2024-03-20T15:48:42.000Z (11 months ago)
- Last Synced: 2024-10-12T16:44:16.526Z (4 months ago)
- Topics: probabilistic-programming
- Language: Python
- Homepage: https://rlouf.github.io/mcx
- Size: 882 KB
- Stars: 324
- Watchers: 17
- Forks: 17
- Open Issues: 19
Metadata Files:
- Readme: README.md
- Contributing: CONTRIBUTING.md
- License: LICENSE
- Code of conduct: CODE_OF_CONDUCT.md
Awesome Lists containing this project
- awesome-jax - mcx - Express & compile probabilistic programs for performant inference. <img src="https://img.shields.io/github/stars/rlouf/mcx?style=social" align="center"> (Libraries / New Libraries)
- awesome-jax - mcx - Express & compile probabilistic programs for performant inference on CPU & GPU. Powered by JAX. <img src="https://img.shields.io/github/stars/rlouf/mcx?style=social" align="center"> (Libraries / Inactive Libraries)
README
MCX
XLA-rated Bayesian inference

MCX is a probabilistic programming library with a laser focus on sampling
methods. MCX transforms model definitions to generate logpdf or sampling
functions. These functions are JIT-compiled with JAX; they support batching and
can be executed transparently on CPU, GPU, or TPU.

The project is currently in its infancy and is a moonshot towards providing
sequential inference as a first-class citizen, and performant sampling methods
for Bayesian deep learning.

## MCX's philosophy
1. Knowing how to express a graphical model and manipulate NumPy arrays should
   be enough to define a model.
2. Models should be modular and reusable.
3. Inference should be performant and should leverage GPUs.

See the [documentation](https://rlouf.github.io/mcx) for more information. See
[this issue](https://github.com/rlouf/mcx/issues/1) for an updated roadmap for v0.1.
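To make the "transforms model definitions into logpdf functions" idea concrete, here is a hand-written, pure-Python sketch of the joint log-density that a small linear-regression model corresponds to. This is an illustration only, not MCX's actual generated code; the helper names are made up for this example:

```python
import math

def exponential_logpdf(x, rate):
    # Log-density of Exponential(rate) evaluated at x >= 0.
    return math.log(rate) - rate * x

def normal_logpdf(x, loc, scale):
    # Log-density of Normal(loc, scale) evaluated at x.
    return -0.5 * math.log(2 * math.pi * scale ** 2) - (x - loc) ** 2 / (2 * scale ** 2)

def linear_regression_logpdf(x, preds, scale, coefs, lmbda=1.0):
    # Joint log-density of a one-dimensional linear regression:
    #   scale ~ Exponential(lmbda)
    #   coefs ~ Normal(0, 1)
    #   preds ~ Normal(x * coefs, scale)
    logp = exponential_logpdf(scale, lmbda)
    logp += normal_logpdf(coefs, 0.0, 1.0)
    logp += normal_logpdf(preds, x * coefs, scale)
    return logp

print(round(linear_regression_logpdf(x=2.0, preds=6.0, scale=1.0, coefs=3.0), 4))
# -7.3379
```

A gradient-based sampler such as HMC only needs this function and its gradient, which is why compiling the model with JAX (via `jit` and `grad`) is enough to drive performant inference.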
## Current API
Note that there are still many moving pieces in `mcx` and the API may change
slightly.

```python
import arviz as az
import jax
import jax.numpy as jnp
import numpy as np

import mcx
from mcx.distributions import Exponential, Normal
from mcx.inference import HMC

rng_key = jax.random.PRNGKey(0)
x_data = np.random.normal(0, 5, size=(1000, 1))
y_data = 3 * x_data + np.random.normal(size=x_data.shape)

@mcx.model
def linear_regression(x, lmbda=1.):
    scale <~ Exponential(lmbda)
    coefs <~ Normal(jnp.zeros(jnp.shape(x)[-1]), 1)
    preds <~ Normal(jnp.dot(x, coefs), scale)
    return preds

prior_predictive = mcx.prior_predict(rng_key, linear_regression, (x_data,))

posterior = mcx.sampler(
    rng_key,
    linear_regression,
    (x_data,),
    {'preds': y_data},
    HMC(100),
).run()

az.plot_trace(posterior)
posterior_predictive = mcx.posterior_predict(rng_key, linear_regression, (x_data,), posterior)
```

## MCX's future
We are currently considering the following directions:
- **Neural network layers:** You can follow discussions about the API in [this Pull Request](https://github.com/rlouf/mcx/pull/16).
- **Programs with stochastic support:** Discussion in this [Issue](https://github.com/rlouf/mcx/issues/37).
- **Tools for causal inference:** Made easier by the internal representation of
  models as graphs.

You are more than welcome to contribute to these discussions, or to suggest
potential future directions.

## Linear sampling
Like most PPLs, MCX implements a batch sampling runtime:
```python
sampler = mcx.sampler(
rng_key,
linear_regression,
*args,
observations,
kernel,
)
posterior = sampler.run()
```

The warmup trace is discarded by default, but you can obtain it by running:
```python
warmup_posterior = sampler.warmup()
posterior = sampler.run()
```

You can extract more samples from the chain after a run and combine the
two traces:

```python
posterior += sampler.run()
```

By default MCX samples in interactive mode using a Python `for` loop and
displays a progress bar and various diagnostics. For faster sampling you can use:

```python
posterior = sampler.run(compile=True)
```

One can use this combination in a notebook to first get a lower bound on the
sampling rate before deciding on a number of samples.

### Interactive sampling
Sampling the posterior is an iterative process. Yet most libraries only provide
batch sampling. The generator runtime is already implemented in `mcx`, which
opens up many possibilities, such as:

- Dynamic interruption of inference (say, after reaching a set number of
  effective samples);
- Real-time monitoring of inference with something like TensorBoard;
- Easier debugging.

```python
samples = mcx.sampler(
rng_key,
linear_regression,
*args,
observations,
kernel,
)

trace = mcx.Trace()
for sample in samples:
    trace.append(sample)

iter(samples)
next(samples)
```

Note that the performance of the interactive mode is significantly lower than
that of the batch sampler. However, both can be used successively:

```python
trace = mcx.Trace()
for i, sample in enumerate(samples):
    print(do_something(sample))
    trace.append(sample)
    if i % 10 == 0:
        trace += samples.run(100_000, compile=True)
```

## Important note
MCX takes a lot of inspiration from other probabilistic programming languages
and libraries: Stan (NUTS and the very knowledgeable community), PyMC3 (for its
simple API), Tensorflow Probability (for its shape system and inference
vectorization), (Num)Pyro (for the use of JAX in the backend), Gen.jl and
Turing.jl (for composable inference), Soss.jl (generative model API), Anglican,
and many others that I forget.