{"id":13688976,"url":"https://github.com/google-deepmind/distrax","last_synced_at":"2025-06-17T00:39:19.160Z","repository":{"id":38236719,"uuid":"353770320","full_name":"google-deepmind/distrax","owner":"google-deepmind","description":null,"archived":false,"fork":false,"pushed_at":"2025-05-16T17:27:35.000Z","size":644,"stargazers_count":568,"open_issues_count":44,"forks_count":32,"subscribers_count":16,"default_branch":"master","last_synced_at":"2025-05-16T18:32:20.277Z","etag":null,"topics":[],"latest_commit_sha":null,"homepage":null,"language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"apache-2.0","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/google-deepmind.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":"CONTRIBUTING.md","funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null,"zenodo":null}},"created_at":"2021-04-01T17:03:49.000Z","updated_at":"2025-05-16T10:41:38.000Z","dependencies_parsed_at":"2023-09-07T20:34:49.066Z","dependency_job_id":"293f5ef3-4778-4114-a3e2-29d075c13cd5","html_url":"https://github.com/google-deepmind/distrax","commit_stats":null,"previous_names":["google-deepmind/distrax","deepmind/distrax"],"tags_count":12,"template":false,"template_full_name":null,"purl":"pkg:github/google-deepmind/distrax","repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/google-deepmind%2Fdistrax","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/google-deepmind%2Fdistrax/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/google-deepmind%2Fdistrax/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/google-deepmind%2Fdistrax/manifests"
,"owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/google-deepmind","download_url":"https://codeload.github.com/google-deepmind/distrax/tar.gz/refs/heads/master","sbom_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/google-deepmind%2Fdistrax/sbom","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":260268635,"owners_count":22983601,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":[],"created_at":"2024-08-02T15:01:29.414Z","updated_at":"2025-06-17T00:39:19.151Z","avatar_url":"https://github.com/google-deepmind.png","language":"Python","funding_links":[],"categories":["Python","Libraries"],"sub_categories":[],"readme":"# Distrax\n\n![CI status](https://github.com/deepmind/distrax/workflows/tests/badge.svg)\n![pypi](https://img.shields.io/pypi/v/distrax)\n\nDistrax is a lightweight library of probability distributions and bijectors. 
It\nacts as a JAX-native reimplementation of a subset of\n[TensorFlow Probability](https://www.tensorflow.org/probability) (TFP), with\nsome new features and emphasis on extensibility.\n\n## Installation\n\nYou can install the latest released version of Distrax from PyPI via:\n\n```sh\npip install distrax\n```\n\nor you can install the latest development version from GitHub:\n\n```sh\npip install git+https://github.com/deepmind/distrax.git\n```\n\nTo run the tests or\n[examples](https://github.com/deepmind/distrax/tree/master/examples) you will\nneed to install additional [requirements](https://github.com/deepmind/distrax/tree/master/requirements).\n\n## Design Principles\n\nThe general design principles for the DeepMind JAX Ecosystem are addressed in\n[this blog](https://deepmind.com/blog/article/using-jax-to-accelerate-our-research).\nAdditionally, Distrax places emphasis on the following:\n\n1. **Readability.** Distrax implementations are intended to be self-contained\nand read as close to the underlying math as possible.\n2. **Extensibility.** We have made it as simple as possible for users to define\ntheir own distribution or bijector. This is useful for example in reinforcement\nlearning, where users may wish to define custom behavior for probabilistic agent\npolicies.\n3. **Compatibility.** Distrax is not intended as a replacement for TFP, and TFP\ncontains many advanced features that we do not intend to replicate. To this end,\nwe have made the APIs for distributions and bijectors as cross-compatible as\npossible, and provide utilities for transforming between equivalent Distrax and\nTFP classes.\n\n## Features\n\n### Distributions\n\nDistributions in Distrax are simple to define and use, particularly if you're\nused to TFP. 
Let's compare the two side-by-side:\n\n```python\nimport distrax\nimport jax\nimport jax.numpy as jnp\n\nfrom tensorflow_probability.substrates import jax as tfp\ntfd = tfp.distributions\n\nkey = jax.random.PRNGKey(1234)\nmu = jnp.array([-1., 0., 1.])\nsigma = jnp.array([0.1, 0.2, 0.3])\n\ndist_distrax = distrax.MultivariateNormalDiag(mu, sigma)\ndist_tfp = tfd.MultivariateNormalDiag(mu, sigma)\n\nsamples = dist_distrax.sample(seed=key)\n\n# Both print 1.775\nprint(dist_distrax.log_prob(samples))\nprint(dist_tfp.log_prob(samples))\n```\n\nIn addition to behaving consistently, Distrax distributions and TFP\ndistributions are cross-compatible. For example:\n\n```python\nmu_0 = jnp.array([-1., 0., 1.])\nsigma_0 = jnp.array([0.1, 0.2, 0.3])\ndist_distrax = distrax.MultivariateNormalDiag(mu_0, sigma_0)\n\nmu_1 = jnp.array([1., 2., 3.])\nsigma_1 = jnp.array([0.2, 0.3, 0.4])\ndist_tfp = tfd.MultivariateNormalDiag(mu_1, sigma_1)\n\n# Both print 85.237\nprint(dist_distrax.kl_divergence(dist_tfp))\nprint(tfd.kl_divergence(dist_distrax, dist_tfp))\n```\n\nDistrax distributions implement the method `sample_and_log_prob`, which returns\nsamples and their log-probabilities in a single call. For some distributions, this is\nmore efficient than calling `sample` and `log_prob` separately:\n\n```python\nmu = jnp.array([-1., 0., 1.])\nsigma = jnp.array([0.1, 0.2, 0.3])\ndist_distrax = distrax.MultivariateNormalDiag(mu, sigma)\n\nsamples = dist_distrax.sample(seed=key, sample_shape=())\nlog_prob = dist_distrax.log_prob(samples)\n\n# A one-line equivalent of the above is:\nsamples, log_prob = dist_distrax.sample_and_log_prob(seed=key, sample_shape=())\n```\n\nTFP distributions can be passed to Distrax meta-distributions as inputs. 
For\nexample:\n\n```python\nkey = jax.random.PRNGKey(1234)\n\nmu = jnp.array([-1., 0., 1.])\nsigma = jnp.array([0.2, 0.3, 0.4])\ndist_tfp = tfd.Normal(mu, sigma)\n\nmetadist_distrax = distrax.Independent(dist_tfp, reinterpreted_batch_ndims=1)\nsamples = metadist_distrax.sample(seed=key)\nprint(metadist_distrax.log_prob(samples))  # Prints 0.38871175\n```\n\nTo use Distrax distributions in TFP meta-distributions, Distrax provides the\nwrapper `to_tfp`. A wrapped Distrax distribution can be directly used in TFP:\n\n```python\nkey = jax.random.PRNGKey(1234)\n\ndistrax_dist = distrax.Normal(0., 1.)\nwrapped_dist = distrax.to_tfp(distrax_dist)\nmetadist_tfp = tfd.Sample(wrapped_dist, sample_shape=[3])\n\nsamples = metadist_tfp.sample(seed=key)\nprint(metadist_tfp.log_prob(samples))  # Prints -3.3409896\n```\n\n### Bijectors\n\nA \"bijector\" in Distrax is an invertible function that knows how to compute its\nJacobian determinant. Bijectors can be used to create complex distributions by\ntransforming simpler ones. Distrax bijectors are functionally similar to TFP\nbijectors, with a few API differences. Here is an example comparing the two:\n\n```python\nimport distrax\nimport jax.numpy as jnp\n\nfrom tensorflow_probability.substrates import jax as tfp\ntfb = tfp.bijectors\ntfd = tfp.distributions\n\n# Same distribution.\ndistrax.Transformed(distrax.Normal(loc=0., scale=1.), distrax.Tanh())\ntfd.TransformedDistribution(tfd.Normal(loc=0., scale=1.), tfb.Tanh())\n```\n\nAdditionally, Distrax bijectors can be composed and inverted:\n\n```python\nbij_distrax = distrax.Tanh()\nbij_tfp = tfb.Tanh()\n\n# Same bijector.\ninv_bij_distrax = distrax.Inverse(bij_distrax)\ninv_bij_tfp = tfb.Invert(bij_tfp)\n\n# These are both the identity bijector.\ndistrax.Chain([bij_distrax, inv_bij_distrax])\ntfb.Chain([bij_tfp, inv_bij_tfp])\n```\n\nAll TFP bijectors can be passed to Distrax, and can be freely composed with\nDistrax bijectors. 
For example, all of the following will work:\n\n```python\ndistrax.Inverse(tfb.Tanh())\n\ndistrax.Chain([tfb.Tanh(), distrax.Tanh()])\n\ndistrax.Transformed(tfd.Normal(loc=0., scale=1.), tfb.Tanh())\n```\n\nDistrax bijectors can also be passed to TFP, but first they must be transformed\nwith `to_tfp`:\n\n```python\nbij_distrax = distrax.to_tfp(distrax.Tanh())\n\ntfb.Invert(bij_distrax)\n\ntfb.Chain([tfb.Tanh(), bij_distrax])\n\ntfd.TransformedDistribution(tfd.Normal(loc=0., scale=1.), bij_distrax)\n```\n\nDistrax also comes with `Lambda`, a convenient wrapper for turning simple JAX\nfunctions into bijectors. Here are a few `Lambda` examples with their TFP\nequivalents:\n\n```python\ndistrax.Lambda(lambda x: x)\n# tfb.Identity()\n\ndistrax.Lambda(lambda x: 2*x + 3)\n# tfb.Chain([tfb.Shift(3), tfb.Scale(2)])\n\ndistrax.Lambda(jnp.sinh)\n# tfb.Sinh()\n\ndistrax.Lambda(lambda x: jnp.sinh(2*x + 3))\n# tfb.Chain([tfb.Sinh(), tfb.Shift(3), tfb.Scale(2)])\n```\n\nUnlike TFP, bijectors in Distrax do not take `event_ndims` as an argument when\nthey compute the Jacobian determinant. Instead, Distrax assumes that the number\nof event dimensions is statically known to every bijector, and uses\n`Block` to lift bijectors to a different number of dimensions. 
For example:\n\n```python\nx = jnp.zeros([2, 3, 4])\n\n# In TFP, `event_ndims` can be passed to the bijector.\nbij_tfp = tfb.Tanh()\nld_1 = bij_tfp.forward_log_det_jacobian(x, event_ndims=0)  # Shape = [2, 3, 4]\n\n# Distrax assumes `Tanh` is a scalar bijector by default.\nbij_distrax = distrax.Tanh()\nld_2 = bij_distrax.forward_log_det_jacobian(x)  # ld_1 == ld_2\n\n# With `event_ndims=2`, TFP sums the last 2 dimensions of the log det.\nld_3 = bij_tfp.forward_log_det_jacobian(x, event_ndims=2)  # Shape = [2]\n\n# Distrax treats the number of dimensions statically.\nbij_distrax = distrax.Block(bij_distrax, ndims=2)\nld_4 = bij_distrax.forward_log_det_jacobian(x)  # ld_3 == ld_4\n```\n\nDistrax bijectors implement the method `forward_and_log_det` (some bijectors\nadditionally implement `inverse_and_log_det`), which returns the\n`forward` mapping and its log Jacobian determinant in a single call. For some\nbijectors, this is more efficient than calling `forward` and\n`forward_log_det_jacobian` separately. (Analogously, when available, `inverse_and_log_det`\ncan be more efficient than calling `inverse` and `inverse_log_det_jacobian` separately.)\n\n```python\nx = jnp.zeros([2, 3, 4])\nbij_distrax = distrax.Tanh()\n\ny = bij_distrax.forward(x)\nld = bij_distrax.forward_log_det_jacobian(x)\n\n# A one-line equivalent of the above is:\ny, ld = bij_distrax.forward_and_log_det(x)\n```\n\n### Jitting Distrax\n\nDistrax distributions and bijectors can be passed as arguments to jitted\nfunctions. User-defined distributions and bijectors get this property for free\nby subclassing `distrax.Distribution` and `distrax.Bijector` respectively. 
For\nexample:\n\n```python\nmu_0 = jnp.array([-1., 0., 1.])\nsigma_0 = jnp.array([0.1, 0.2, 0.3])\ndist_0 = distrax.MultivariateNormalDiag(mu_0, sigma_0)\n\nmu_1 = jnp.array([1., 2., 3.])\nsigma_1 = jnp.array([0.2, 0.3, 0.4])\ndist_1 = distrax.MultivariateNormalDiag(mu_1, sigma_1)\n\njitted_kl = jax.jit(lambda d_0, d_1: d_0.kl_divergence(d_1))\n\n# Both print 85.237\nprint(jitted_kl(dist_0, dist_1))\nprint(dist_0.kl_divergence(dist_1))\n```\n\n##### A note about `vmap` and `pmap`\n\nThe serialization logic that enables Distrax objects to be passed as arguments\nto jitted functions also enables functions to map over them as data using\n`jax.vmap` and `jax.pmap`.\n\nHowever, ***support for this behavior is experimental and incomplete. Use\ncaution when applying `jax.vmap` or `jax.pmap` to functions that take Distrax\nobjects as arguments, or return Distrax objects.***\n\nSimple objects such as `distrax.Categorical` may behave correctly under these\ntransformations, but more complex objects such as\n`distrax.MultivariateNormalDiag` may generate exceptions when used as inputs or\noutputs of a `vmap`-ed or `pmap`-ed function.\n\n\n### Subclassing Distributions and Bijectors\n\nUser-defined distributions can be created by subclassing `distrax.Distribution`.\nThis can be achieved by implementing only a few methods:\n\n```python\nclass MyDistribution(distrax.Distribution):\n\n  def __init__(self, ...):\n    ...\n\n  def _sample_n(self, key, n):\n    samples = ...\n    return samples\n\n  def log_prob(self, value):\n    log_prob = ...\n    return log_prob\n\n  def event_shape(self):\n    event_shape = ...\n    return event_shape\n\n  def _sample_n_and_log_prob(self, key, n):\n    # Optional. Only when more efficient implementation is possible.\n    samples, log_prob = ...\n    return samples, log_prob\n```\n\nSimilarly, more complicated bijectors can be created by subclassing\n`distrax.Bijector`. 
This can be achieved by implementing only one or two\nmethods:\n\n```python\nclass MyBijector(distrax.Bijector):\n\n  def __init__(self, ...):\n    super().__init__(...)\n\n  def forward_and_log_det(self, x):\n    y = ...\n    logdet = ...\n    return y, logdet\n\n  def inverse_and_log_det(self, y):\n    # Optional. Can be omitted if inverse methods are not needed.\n    x = ...\n    logdet = ...\n    return x, logdet\n```\n\n## Examples\n\nThe `examples` directory contains some representative examples of full programs\nthat use Distrax.\n\n`hmm.py` demonstrates how to use `distrax.HMM` to combine distributions that\nmodel the initial states, transitions, and observation distributions of a\nHidden Markov Model, and infer the latent rates and state transitions in a\nchanging noisy signal.\n\n`vae.py` contains an example implementation of a variational auto-encoder that\nis trained to model the binarized MNIST dataset as a joint `distrax.Bernoulli`\ndistribution over the pixels.\n\n`flow.py` illustrates a simple example of modelling MNIST data using\n`distrax.MaskedCoupling` layers to implement a normalizing flow, and training\nthe model with gradient descent.\n\n## Acknowledgements\n\nWe greatly appreciate the ongoing support of the TensorFlow Probability authors\nin assisting with the design and cross-compatibility of Distrax.\n\nSpecial thanks to Aleyna Kara and Kevin Murphy for contributing the code upon\nwhich the Hidden Markov Model and associated example are based.\n\n## Citing Distrax\n\nThis repository is part of the DeepMind JAX Ecosystem. To cite Distrax,\nplease use the following citation:\n\n```bibtex\n@software{deepmind2020jax,\n  title = {The {D}eep{M}ind {JAX} {E}cosystem},\n  author = {DeepMind and Babuschkin, Igor and Baumli, Kate and Bell, Alison and Bhupatiraju, Surya and Bruce, Jake and Buchlovsky, Peter and Budden, David and Cai, Trevor and Clark, Aidan and Danihelka, Ivo and Dedieu, Antoine and Fantacci, Claudio and Godwin, Jonathan and Jones, Chris 
and Hemsley, Ross and Hennigan, Tom and Hessel, Matteo and Hou, Shaobo and Kapturowski, Steven and Keck, Thomas and Kemaev, Iurii and King, Michael and Kunesch, Markus and Martens, Lena and Merzic, Hamza and Mikulik, Vladimir and Norman, Tamara and Papamakarios, George and Quan, John and Ring, Roman and Ruiz, Francisco and Sanchez, Alvaro and Sartran, Laurent and Schneider, Rosalia and Sezener, Eren and Spencer, Stephen and Srinivasan, Srivatsan and Stanojevi\\'{c}, Milo\\v{s} and Stokowiec, Wojciech and Wang, Luyu and Zhou, Guangyao and Viola, Fabio},\n  url = {http://github.com/deepmind},\n  year = {2020},\n}\n```\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fgoogle-deepmind%2Fdistrax","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fgoogle-deepmind%2Fdistrax","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fgoogle-deepmind%2Fdistrax/lists"}