{"id":13521041,"url":"https://github.com/blackjax-devs/blackjax","last_synced_at":"2025-10-21T19:58:48.582Z","repository":{"id":37206998,"uuid":"319886963","full_name":"blackjax-devs/blackjax","owner":"blackjax-devs","description":"BlackJAX is a Bayesian Inference library designed for ease of use, speed and modularity.","archived":false,"fork":false,"pushed_at":"2025-02-19T15:02:10.000Z","size":434404,"stargazers_count":886,"open_issues_count":97,"forks_count":110,"subscribers_count":15,"default_branch":"main","last_synced_at":"2025-03-19T02:19:34.090Z","etag":null,"topics":["bayesian-inference","hamiltonian-monte-carlo","probabilistic-programming","sampling-methods"],"latest_commit_sha":null,"homepage":"https://blackjax-devs.github.io/blackjax/","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"apache-2.0","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/blackjax-devs.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":"CONTRIBUTING.md","funding":null,"license":"LICENSE","code_of_conduct":"CODE_OF_CONDUCT.md","threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":"GOVERNANCE.md","roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null}},"created_at":"2020-12-09T08:12:12.000Z","updated_at":"2025-03-10T21:04:48.000Z","dependencies_parsed_at":"2023-02-17T01:01:16.052Z","dependency_job_id":"536aa6ca-a9ac-45e9-8d31-2c4cedc7eb62","html_url":"https://github.com/blackjax-devs/blackjax","commit_stats":{"total_commits":414,"total_committers":44,"mean_commits":9.409090909090908,"dds":0.6086956521739131,"last_synced_commit":"39760efcbf796c5010ac9065f2f86c6adf8d5945"},"previous_names":[],"tags_count":28,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/blackjax-devs%2Fblackjax","tags_url":"https://repos.ec
osyste.ms/api/v1/hosts/GitHub/repositories/blackjax-devs%2Fblackjax/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/blackjax-devs%2Fblackjax/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/blackjax-devs%2Fblackjax/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/blackjax-devs","download_url":"https://codeload.github.com/blackjax-devs/blackjax/tar.gz/refs/heads/main","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":246535726,"owners_count":20793312,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["bayesian-inference","hamiltonian-monte-carlo","probabilistic-programming","sampling-methods"],"created_at":"2024-08-01T06:00:27.065Z","updated_at":"2025-10-21T19:58:43.544Z","avatar_url":"https://github.com/blackjax-devs.png","language":"Python","funding_links":[],"categories":["See also: other libraries in the JAX ecosystem","Software","Python","Libraries","Inference"],"sub_categories":[],"readme":"# BlackJAX\n![Continuous integration](https://github.com/blackjax-devs/blackjax/actions/workflows/test.yml/badge.svg)\n![codecov](https://codecov.io/gh/blackjax-devs/blackjax/branch/main/graph/badge.svg)\n![PyPI version](https://img.shields.io/pypi/v/blackjax)\n\n\n![BlackJAX animation: sampling BlackJAX with BlackJAX](./docs/examples/scatter.gif)\n\n## What is BlackJAX?\n\nBlackJAX is a library of samplers for [JAX](https://github.com/google/jax) that\nworks on CPU as well as GPU.\n\nIt is *not* a probabilistic programming library. 
However, it integrates\nwell with PPLs as long as they can provide a (potentially unnormalized)\nlog-probability density function compatible with JAX.\n\n## Who should use BlackJAX?\n\nBlackJAX should appeal to those who:\n- Have a logpdf and just need a sampler;\n- Need more than a general-purpose sampler;\n- Want to sample on GPU;\n- Want to build upon robust elementary blocks for their research;\n- Are building a probabilistic programming language;\n- Want to learn how sampling algorithms work.\n\n## Quickstart\n\n### Installation\n\nYou can install BlackJAX using `pip`:\n\n```bash\npip install blackjax\n```\n\nor via conda-forge:\n\n```bash\nconda install -c conda-forge blackjax\n```\n\nBlackJAX is written in pure Python but depends on XLA via JAX. By default, the\nversion of JAX installed along with BlackJAX runs your code on CPU only.\n**If you want to use BlackJAX on GPU/TPU**, we recommend you follow\n[these instructions](https://github.com/google/jax#installation) to install JAX\nwith the relevant hardware acceleration support.\n\n### Example\n\nLet us look at a simple self-contained example sampling with NUTS:\n\n```python\nimport jax\nimport jax.numpy as jnp\nimport jax.scipy.stats as stats\nimport numpy as np\n\nimport blackjax\n\n# Synthetic data: 1,000 draws from a normal with mean 10 and standard deviation 20\nobserved = np.random.normal(10, 20, size=1_000)\n\n\ndef logdensity_fn(x):\n    # Log-likelihood of the data; `scale` is unconstrained here, so a real\n    # model would typically sample log(scale) to keep it positive\n    logpdf = stats.norm.logpdf(observed, x[\"loc\"], x[\"scale\"])\n    return jnp.sum(logpdf)\n\n\n# Build the kernel\nstep_size = 1e-3\ninverse_mass_matrix = jnp.array([1., 1.])\nnuts = blackjax.nuts(logdensity_fn, step_size, inverse_mass_matrix)\n\n# Initialize the state\ninitial_position = {\"loc\": 1., \"scale\": 2.}\nstate = nuts.init(initial_position)\n\n# Iterate: derive a fresh key from `rng_key` at every step\nrng_key = jax.random.key(0)\nstep = jax.jit(nuts.step)\nfor i in range(100):\n    nuts_key = jax.random.fold_in(rng_key, i)\n    state, _ = step(nuts_key, state)\n```\n\nSee [the documentation](https://blackjax-devs.github.io/blackjax/index.html) for more examples 
of how to use the library: how to write inference loops for one or several chains, how to use the Stan warmup, etc.\n\n## Philosophy\n\n### What is BlackJAX?\n\nBlackJAX bridges the gap between \"one-liner\" frameworks and modular, customizable\nlibraries.\n\nUsers can import the library and interact with robust, well-tested and performant\nsamplers with a few lines of code. These samplers are aimed at PPL developers,\nor people who have a logpdf and just need a sampler that works.\n\nBut the true strength of BlackJAX lies in its internals and how they can be used\nto experiment quickly on existing or new sampling schemes. This lower level\nexposes the building blocks of inference algorithms (integrators, proposals,\nmomentum generators, etc.) and makes it easy to combine them into new\nalgorithms. It can accelerate research on sampling\nalgorithms by providing robust, performant and reusable code.\n\n### Why BlackJAX?\n\nSampling algorithms are too often integrated into PPLs and not decoupled from\nthe rest of the framework, making them hard to use for people who do not need\nthe modeling language to build their logpdf. Their implementations are usually\nmonolithic, making it impossible to reuse parts of an algorithm to\nbuild custom kernels. BlackJAX solves both problems.\n\n### How does it work?\n\nBlackJAX makes it possible to build arbitrarily complex algorithms because it is built\naround a very general pattern. Everything that takes a state and returns a state\nis a transition kernel, and is implemented as:\n\n```python\nnew_state, info = kernel(rng_key, state)\n```\n\nKernels are stateless functions that all follow the same API; the new state and\nthe information about the transition are returned separately. They can thus be\neasily composed and exchanged. 
Kernels are specialized through closures rather than by\npassing parameters at each call.\n\n## Contributions\n\nPlease follow our [short guide](https://github.com/blackjax-devs/blackjax/blob/main/CONTRIBUTING.md).\n\n## Citing BlackJAX\n\nTo cite this repository:\n\n```\n@misc{cabezas2024blackjax,\n      title={BlackJAX: Composable {B}ayesian inference in {JAX}},\n      author={Alberto Cabezas and Adrien Corenflos and Junpeng Lao and Rémi Louf},\n      year={2024},\n      eprint={2402.10797},\n      archivePrefix={arXiv},\n      primaryClass={cs.MS}\n}\n```\n\nIn the BibTeX entry above, names are in alphabetical order; the version number should be the last tag on the `main` branch.\n\n## Acknowledgements\n\nSome details of the NUTS implementation were largely inspired by\n[NumPyro](https://github.com/pyro-ppl/numpyro)'s.\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fblackjax-devs%2Fblackjax","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fblackjax-devs%2Fblackjax","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fblackjax-devs%2Fblackjax/lists"}