{"id":13478296,"url":"https://github.com/Joshuaalbert/jaxns","last_synced_at":"2025-03-27T07:31:02.735Z","repository":{"id":38337247,"uuid":"284098897","full_name":"Joshuaalbert/jaxns","owner":"Joshuaalbert","description":"Probabilistic Programming and Nested sampling in JAX","archived":false,"fork":false,"pushed_at":"2024-05-15T10:21:36.000Z","size":34673,"stargazers_count":125,"open_issues_count":9,"forks_count":8,"subscribers_count":5,"default_branch":"main","last_synced_at":"2024-05-16T14:20:20.558Z","etag":null,"topics":["probabilistic-programming"],"latest_commit_sha":null,"homepage":"https://jaxns.readthedocs.io/","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"other","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/Joshuaalbert.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null}},"created_at":"2020-07-31T18:00:18.000Z","updated_at":"2024-05-29T16:24:28.463Z","dependencies_parsed_at":"2023-10-02T21:08:58.837Z","dependency_job_id":"5b0dfd9d-9721-4ad8-a179-60fbb3024a12","html_url":"https://github.com/Joshuaalbert/jaxns","commit_stats":{"total_commits":387,"total_committers":3,"mean_commits":129.0,"dds":0.2661498708010336,"last_synced_commit":"ed1cca7554d3f3517adf7a18824446bd072a65da"},"previous_names":[],"tags_count":0,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/Joshuaalbert%2Fjaxns","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/Joshuaalbert%2Fjaxns/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/Joshuaalbert%2Fjaxns/releases","manifests_url":"htt
ps://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/Joshuaalbert%2Fjaxns/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/Joshuaalbert","download_url":"https://codeload.github.com/Joshuaalbert/jaxns/tar.gz/refs/heads/main","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":245802316,"owners_count":20674634,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["probabilistic-programming"],"created_at":"2024-07-31T16:01:55.162Z","updated_at":"2025-03-27T07:30:57.712Z","avatar_url":"https://github.com/Joshuaalbert.png","language":"Python","funding_links":[],"categories":["Python","Models and Projects","Libraries"],"sub_categories":["JAX"],"readme":"[![Python](https://img.shields.io/pypi/pyversions/jaxns.svg)](https://badge.fury.io/py/jaxns)\n[![PyPI](https://badge.fury.io/py/jaxns.svg)](https://badge.fury.io/py/jaxns)\n[![Documentation Status](https://readthedocs.org/projects/jaxns/badge/?version=latest)](https://jaxns.readthedocs.io/en/latest/?badge=latest)\n\nMain\nStatus: ![Workflow name](https://github.com/JoshuaAlbert/jaxns/actions/workflows/unittests.yml/badge.svg?branch=main)\n\nDevelop\nStatus: ![Workflow name](https://github.com/JoshuaAlbert/jaxns/actions/workflows/unittests.yml/badge.svg?branch=develop)\n\n![JAXNS](https://github.com/JoshuaAlbert/jaxns/raw/main/jaxns_logo.png)\n\n## Mission: _To make nested sampling **faster, easier, and more powerful**_\n\n# What is it?\n\nJAXNS is:\n\n1) a probabilistic programming framework using nested sampling as the engine;\n2) coded in JAX in a manner that allows 
lowering the entire inference algorithm to XLA primitives, which are\n   JIT-compiled for high performance;\n3) continuously improving on its mission of making nested sampling faster, easier, and more powerful; and\n4) citable: use the [(old) pre-print here](https://arxiv.org/abs/2012.15286).\n\n## JAXNS Probabilistic Programming Framework\n\nJAXNS provides a powerful JAX-based probabilistic programming framework, which allows you to define probabilistic\nmodels easily and use them for advanced purposes. Probabilistic models can have both Bayesian and parameterised\nvariables.\nBayesian variables are random variables and are sampled from a prior distribution.\nParameterised variables are point-wise representations of a prior distribution, and are thus not random.\nAssociated with them is the log-probability of the prior distribution at that point.\n\nLet's break apart an example of a simple probabilistic model. Note that this example can also be followed\nin [docs/examples/intro_example.ipynb](docs/examples/intro_example.ipynb).\n\n### Defining a probabilistic model\n\nPrior models are generator functions that yield `Prior` objects.\nThe function must eventually return the inputs to the likelihood function.\nThe value returned by a yielded `Prior` is a simple JAX array, i.e. you can do anything you want with it using JAX ops.\nThe rules of static programming apply, i.e. you cannot dynamically allocate arrays.\n\nJAXNS makes use of the TensorFlow Probability (TFP) library for defining prior distributions, thus you can use __almost__\nany of the TFP distributions. You can also use any of the TFP bijectors to define transformed distributions.\n\nDistributions do have some requirements to be valid for use in JAXNS.\n\n1. They must have a quantile function, i.e. `dist.quantile(dist.cdf(x)) == x`.\n2. 
They must have a `log_prob` method that returns the log-probability of the distribution at a given value.\n\nMost of the TFP distributions satisfy these requirements.\n\nJAXNS defines some special priors that can't be built from TFP distributions; see `jaxns.framework.special_priors`. You can\nalways request more if you need them.\n\nPrior variables __may__ be named but don't have to be. If they are named then they can be collected later via a\ntransformation, otherwise they are deemed hidden variables.\n\nThe output values of prior models are the inputs to the likelihood function. They can be PyTrees,\ne.g. `typing.NamedTuple`s.\n\nFinally, priors can become point-wise estimates of the prior distribution by calling `parametrised()`. This turns a\nBayesian variable into a parameterised variable, i.e. one which can be used in optimisation.\n\n```python\nimport jax\nimport tensorflow_probability.substrates.jax as tfp\n\ntfpd = tfp.distributions\n\nfrom jaxns.framework.model import Model\nfrom jaxns.framework.prior import Prior\n\n\ndef prior_model():\n    # mu is unnamed, so it is treated as a hidden variable\n    mu = yield Prior(tfpd.Normal(loc=0., scale=1.))\n    # Let's make sigma a parameterised variable\n    sigma = yield Prior(tfpd.Exponential(rate=1.), name='sigma').parametrised()\n    x = yield Prior(tfpd.Cauchy(loc=mu, scale=sigma), name='x')\n    uncert = yield Prior(tfpd.Exponential(rate=1.), name='uncert')\n    return x, uncert\n\n\ndef log_likelihood(x, uncert):\n    return tfpd.Normal(loc=0., scale=uncert).log_prob(x)\n\n\nmodel = Model(prior_model=prior_model, log_likelihood=log_likelihood)\n\n# You can sanity check the model (always a good idea when exploring)\nmodel.sanity_check(key=jax.random.PRNGKey(0), S=100)\n\n# The size of the Bayesian part of the prior space is `model.U_ndims`.\n```\n\n### Sampling and transforming variables\n\nThere are two spaces of samples:\n\n1. U-space: samples in the base measure space, which is dimensionless, or rather has units of probability.\n2. 
X-space: samples in the space of the model, which have the units of the prior variables.\n\n```python\n# Sample the prior in U-space (base measure)\nU = model.sample_U(key=jax.random.PRNGKey(0))\n# Transform to X-space\nX = model.transform(U=U)\n# Only named Bayesian prior variables are returned; the rest are treated as hidden variables.\nassert set(X.keys()) == {'x', 'uncert'}\n\n# Get the return value of the prior model, i.e. the input to the likelihood\nx_sample, uncert_sample = model.prepare_input(U=U)\n```\n\n### Computing log-probabilities\n\nAll computations are based on the U-space variables.\n\n```python\n# Evaluate different parts of the model\nlog_prob_prior = model.log_prob_prior(U)\nlog_prob_likelihood = model.log_prob_likelihood(U, allow_nan=False)\nlog_prob_joint = model.log_prob_joint(U, allow_nan=False)\n```\n\n### Computing gradients of the joint probability w.r.t. parameters\n\n```python\ninit_params = model.params\n\n\ndef log_prob_joint_fn(params, U):\n    # Calling model with params returns a new model with the params set\n    return model(params).log_prob_joint(U, allow_nan=False)\n\n\nvalue, grad = jax.value_and_grad(log_prob_joint_fn)(init_params, U)\n```\n\n## Nested Sampling Engine\n\nGiven a probabilistic model, JAXNS can perform nested sampling on it. 
This allows computing the Bayesian evidence and\nposterior samples.\n\n```python\nfrom jaxns import NestedSampler\n\nns = NestedSampler(model=model, max_samples=1e5)\n\n# Run the sampler\ntermination_reason, state = ns(jax.random.PRNGKey(42))\n# Get the results\nresults = ns.to_results(termination_reason=termination_reason, state=state)\n```\n\n#### To AOT or JIT-compile the sampler\n\n```python\n# Ahead-of-time compilation (sometimes useful)\nns_aot = jax.jit(ns).lower(jax.random.PRNGKey(42)).compile()\n\n# Just-in-time compilation (usually useful)\nns_jit = jax.jit(ns)\n```\n\nYou can inspect the results and plot them.\n\n```python\nfrom jaxns import summary, plot_diagnostics, plot_cornerplot, save_results, load_results\n\n# Optionally save the results to file\nsave_results(results, 'results.json')\n# To load the results back, use this\nresults = load_results('results.json')\n\nsummary(results)\nplot_diagnostics(results)\nplot_cornerplot(results)\n```\n\nOutput:\n\n```\n--------\nTermination Conditions:\nSmall remaining evidence\n--------\nlikelihood evals: 149918\nsamples: 3780\nphantom samples: 1710\nlikelihood evals / sample: 39.7\nphantom fraction (%): 45.2%\n--------\nlogZ=-1.65 +- 0.15\nH=-1.13\nESS=132\n--------\nuncert: mean +- std.dev. | 10%ile / 50%ile / 90%ile | MAP est. | max(L) est.\nuncert: 0.68 +- 0.58 | 0.13 / 0.48 / 1.37 | 0.0 | 0.0\n--------\nx: mean +- std.dev. | 10%ile / 50%ile / 90%ile | MAP est. | max(L) est.\nx: 0.07 +- 0.62 | -0.57 / 0.06 / 0.73 | 0.0 | 0.0\n--------\n```\n\n![Diagnostics plot](docs/examples/intro_diagnostics.png)\n![Corner plot](docs/examples/intro_cornerplot.png)\n\n### Using the posterior samples\n\nNested sampling produces weighted posterior samples. 
For most use cases, you can simply resample them (with\nreplacement).\n\n```python\nfrom jaxns import resample\n\nsamples = resample(\n    key=jax.random.PRNGKey(0),\n    samples=results.samples,\n    log_weights=results.log_dp_mean,\n    S=1000,\n    replace=True\n)\n```\n\n### Maximising the evidence\n\nThe Bayesian evidence is the ultimate model selection density, and choosing a model that maximises the evidence is\nthe best way to select a model. We can use the evidence maximisation algorithm to optimise the parametrised variables\nof the model so as to maximise the evidence. Below, `EvidenceMaximisation` does this for the model we defined\nabove; the parametrised variables are\nautomatically constrained to the right range, and numerical stability is ensured with proper scaling.\n\nWe see that evidence maximisation chooses a very small `sigma`.\n\n```python\nfrom jaxns.experimental import EvidenceMaximisation\n\n# Let's train the sigma parameter to maximise the evidence\n\nem = EvidenceMaximisation(model, ns_kwargs=dict(max_samples=1e4))\nresults, params = em.train(num_steps=5)\n\nsummary(results, with_parametrised=True)\n```\n\nOutput:\n\n```\n--------\nTermination Conditions:\nSmall remaining evidence\n--------\nlikelihood evals: 72466\nsamples: 1440\nphantom samples: 0\nlikelihood evals / sample: 50.3\nphantom fraction (%): 0.0%\n--------\nlogZ=-1.119 +- 0.098\nH=-0.93\nESS=241\n--------\nsigma: mean +- std.dev. | 10%ile / 50%ile / 90%ile | MAP est. | max(L) est.\nsigma: 5.40077599e-05 +- 3.6e-12 | 5.40077563e-05 / 5.40077563e-05 / 5.40077563e-05 | 5.40077563e-05 | 5.40077563e-05\n--------\nuncert: mean +- std.dev. | 10%ile / 50%ile / 90%ile | MAP est. | max(L) est.\nuncert: 0.6 +- 0.54 | 0.05 / 0.45 / 1.37 | 0.0 | 0.0\n--------\nx: mean +- std.dev. | 10%ile / 50%ile / 90%ile | MAP est. 
| max(L) est.\nx: 0.01 +- 0.56 | -0.6 / -0.0 / 0.69 | 0.0 | -0.0\n--------\n```\n\n# Documentation\n\nYou can read the documentation [here](https://jaxns.readthedocs.io/en/latest/#). In addition, JAXNS is partially\ndescribed in the\n[original paper](https://arxiv.org/abs/2012.15286), as well as the [Phantom-Powered Nested\nSampling paper](https://arxiv.org/abs/2312.11330).\n\n# Install\n\n**Notes:**\n\n1. JAXNS requires Python \u003e= 3.9. Using the latest version of Python is highly recommended.\n2. Using a separate virtual environment for each project is highly recommended.\n   To use **miniconda**, ensure it is installed on your system, then run the following commands:\n\n```bash\n# To create a new env, if necessary\nconda create -n jaxns_py python=3.12\nconda activate jaxns_py\n```\n\n## For end users\n\nInstall directly from PyPI:\n\n```bash\npip install jaxns\n```\n\n## For development\n\nClone the repo with `git clone https://www.github.com/JoshuaAlbert/jaxns.git`, then install:\n\n```bash\ncd jaxns\npip install -r requirements.txt\npip install -r requirements-tests.txt\npip install -r requirements-examples.txt\npip install .\n```\n\n# Getting help and contributing examples\n\nDo you have a neat Bayesian problem that you want to solve with JAXNS?\nI strongly encourage anyone in either the scientific community or industry to get involved and join the discussion\nforum.\nPlease use the [GitHub discussion forum](https://github.com/JoshuaAlbert/jaxns/discussions) to get help or\ncontribute examples/neat use cases.\n\n# Quick start\n\nCheck out the examples [here](https://jaxns.readthedocs.io/en/latest/#).\n\n## Caveats\n\nThe caveat is that you need to be able to define your likelihood function with JAX. 
UPDATE: now you can just\nuse the `@jaxify_likelihood` decorator to run with arbitrary Python likelihoods.\n\n# Speed test comparison with other nested sampling packages\n\nJAXNS is fast because the entire inference algorithm is written in JAX and JIT-compiled to XLA.\nJAXNS is much faster than PolyChord, MultiNEST, and dynesty, typically achieving two to three orders of magnitude\nimprovement in run time for models with cheap likelihood evaluations,\nas shown in the [original paper](https://arxiv.org/abs/2012.15286).\n\nRecently JAXNS has implemented Phantom-Powered Nested Sampling, which improves parameter inference, as shown\nin the [Phantom-Powered Nested Sampling paper](https://arxiv.org/abs/2312.11330).\n\n# Note on performance with parallelisation and GPUs\n\nTo use parallel computing, you can simply pass `devices` to the `NestedSampler` constructor. This will distribute\nsampling over the devices. To use GPUs, pass `jax.devices('gpu')` to the `devices` argument. You can also use all\nyour CPUs by setting `os.environ[\"XLA_FLAGS\"] = f\"--xla_force_host_platform_device_count={os.cpu_count()}\"`\nbefore importing JAX or JAXNS.\n\n# Change Log\n\n25 Sep, 2024 -- JAXNS 2.6.2 released. Fixed some important (not so edge) cases. Made faster. Handle no seed scenarios.\n\n24 Sep, 2024 -- JAXNS 2.6.1 released. Sharded parallel JAXNS. Rewrite of internals to support sharded parallelisation.\n\n20 Aug, 2024 -- JAXNS 2.6.0 released. Removed haiku dependency. Implemented our own\ncontext. `jaxns.framework.context.convert_external_params` enables interfacing with any external NN library.\n\n24 Jul, 2024 -- JAXNS 2.5.3 released. Replaced framework U-space with W-space. Maintained external API in U-space.\n\n23 Jul, 2024 -- JAXNS 2.5.2 released. Added explicit density prior. Sped up parametrisation. Scan associative\nimplemented.\n\n27 May, 2024 -- JAXNS 2.5.1 released. Fixed minor accuracy degradation introduced in 2.4.13.\n\n15 May, 2024 -- JAXNS 2.5.0 released. Added ability to handle non-JAX likelihoods, e.g. 
if you have a simulation\nframework with Python bindings, you can now use it for likelihoods in JAXNS. Small performance improvements.\n\n22 Apr, 2024 -- JAXNS 2.4.13 released. Fixed bug where slice sampling was not invariant to monotonic transforms of\nthe likelihood.\n\n20 Mar, 2024 -- JAXNS 2.4.12 released. Minor bug fixes and readability improvements. Added Empirical special prior.\n\n5 Mar, 2024 -- JAXNS 2.4.11/b released. Add `random_init` to parametrised variables. Enable special priors to be\nparametrised.\n\n23 Feb, 2024 -- JAXNS 2.4.10 released. Hotfix for import error.\n\n21 Feb, 2024 -- JAXNS 2.4.9 released. Minor improvements to some priors, and bug fixes.\n\n31 Jan, 2024 -- JAXNS 2.4.8 released. Improved global optimisation performance using gradient slicing.\nImproved evidence maximisation.\n\n25 Jan, 2024 -- JAXNS 2.4.6/7 released. Added logging. Use L-BFGS for Evidence Maximisation M-step. Fix bug in finetune.\n\n24 Jan, 2024 -- JAXNS 2.4.5 released. Gradient-based fine-tuning global optimisation using L-BFGS. Added ability to\nsimulate prior models without building a model (for data generation).\n\n15 Jan, 2024 -- JAXNS 2.4.4 released. Fix performance issue for larger `max_samples`. Fixed bug in termination\nconditions. Improved parallel performance.\n\n10 Jan, 2024 -- JAXNS 2.4.2/3 released. Another performance boost, and an experimental global optimiser.\n\n9 Jan, 2024 -- JAXNS 2.4.1 released. Slightly improved performance for larger `max_samples`; still a performance issue.\n\n8 Jan, 2024 -- JAXNS 2.4.0 released. Python 3.9+ becomes supported. Migrate parametrised models to stable.\nAll models can now be parametrised by default, so you can use `hk.Parameter` anywhere in the model.\n\n21 Dec, 2023 -- JAXNS 2.3.4 released. Correction for ESS and logZ uncert. `parameter_estimation` mode.\n\n20 Dec, 2023 -- JAXNS 2.3.2/3 released. Improved default parameters. `difficult_model` mode. Improve plotting.\n\n18 Dec, 2023 -- JAXNS 2.3.1 released. 
Paper open science release. Default parameters from paper.\n\n11 Dec, 2023 -- JAXNS 2.3.0 released. Release of Phantom-Powered Nested Sampling algorithm.\n\n5 Oct, 2023 -- JAXNS 2.2.6 released. Minor update to evidence maximisation.\n\n3 Oct, 2023 -- JAXNS 2.2.5 released. Parametrised priors and evidence maximisation added.\n\n24 Sept, 2023 -- JAXNS 2.2.4 released. Add marginalising from saved U samples.\n\n28 July, 2023 -- JAXNS 2.2.3 released. Bug fix for singular priors.\n\n26 June, 2023 -- JAXNS 2.2.1 released. Multi-ellipsoidal sampler added back in. Adaptive refinement disabled, as a bias\nhas been detected in it.\n\n15 June, 2023 -- JAXNS 2.2.0 released. Added support to allow TFP bijectors to define transformed distributions. Other\nminor improvements.\n\n15 April, 2023 -- JAXNS 2.1.0 released. `pmap` used on outer-most loops, allowing efficient device-device communication\nduring parallel runs.\n\n8 March, 2023 -- JAXNS 2.0.1 released. Changed how we're doing annotations to support Python 3.8 again.\n\n3 January, 2023 -- JAXNS 2.0 released. Complete overhaul of components. New way to build models.\n\n5 August, 2022 -- JAXNS 1.1.1 released. PyTree-shaped priors.\n\n2 June, 2022 -- JAXNS 1.1.0 released. Dynamic sampling takes advantage of adaptive refinement. Parallelisation. Bayesian\nopt and global opt modules.\n\n30 May, 2022 -- JAXNS 1.0.1 released. Improvements to speed, parallelisation, and structure of code.\n\n9 April, 2022 -- JAXNS 1.0.0 released. Parallel sampling, dynamic search, and adaptive refinement. 
Global optimiser\nreleased.\n\n2 Jun, 2021 -- JAXNS 0.0.7 released.\n\n13 May, 2021 -- JAXNS 0.0.6 released.\n\n8 Mar, 2021 -- JAXNS 0.0.5 released.\n\n8 Mar, 2021 -- JAXNS 0.0.4 released.\n\n7 Mar, 2021 -- JAXNS 0.0.3 released.\n\n28 Feb, 2021 -- JAXNS 0.0.2 released.\n\n28 Feb, 2021 -- JAXNS 0.0.1 released.\n\n1 January, 2021 -- Paper submitted\n\n## Star History\n\n\u003ca href=\"https://star-history.com/#joshuaalbert/jaxns\u0026Date\"\u003e\n  \u003cpicture\u003e\n    \u003csource media=\"(prefers-color-scheme: dark)\" srcset=\"https://api.star-history.com/svg?repos=joshuaalbert/jaxns\u0026type=Date\u0026theme=dark\" /\u003e\n    \u003csource media=\"(prefers-color-scheme: light)\" srcset=\"https://api.star-history.com/svg?repos=joshuaalbert/jaxns\u0026type=Date\" /\u003e\n    \u003cimg alt=\"Star History Chart\" src=\"https://api.star-history.com/svg?repos=joshuaalbert/jaxns\u0026type=Date\" /\u003e\n  \u003c/picture\u003e\n\u003c/a\u003e\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2FJoshuaalbert%2Fjaxns","html_url":"https://awesome.ecosyste.ms/projects/github.com%2FJoshuaalbert%2Fjaxns","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2FJoshuaalbert%2Fjaxns/lists"}