# Extensions for NumPyro

This library includes a miscellaneous set of helper functions, custom
distributions, and other utilities that I find useful when using
[NumPyro](https://num.pyro.ai) in my work.

## Installation

Since NumPyro, and hence this library, are built on top of JAX, it's good
practice to start by installing JAX following [the installation
instructions](https://jax.readthedocs.io/en/latest/#installation). Then, you can
install this library using pip:

```bash
python -m pip install numpyro-ext
```

## Usage

Since this README is checked using `doctest`, let's start by importing some
common modules that we'll need in all our examples:

```python
>>> import jax
>>> import jax.numpy as jnp
>>> import numpyro
>>> import numpyro_ext

```

### Distributions

By convention, `numpyro_ext.distributions` is imported as `distx` to
distinguish it from `numpyro.distributions`, which is imported as `dist`:

```python
>>> from numpyro import distributions as dist
>>> from numpyro_ext import distributions as distx
>>> key = jax.random.PRNGKey(0)

```

#### Angle

A uniform distribution over angles in radians. Sampling is actually performed
in the two-dimensional space of vectors proportional to `(sin(theta), cos(theta))`,
so the sampler never sees the discontinuity where the angle wraps around at pi.

```python
>>> angle = distx.Angle()
>>> print(angle.sample(key, (2, 3)))
[[ 0.4...]
 [ 2.4...]]

```

#### UnitDisk

A uniform distribution over two-dimensional points within the disk of radius 1.
This means that, for any sample from this distribution, the sum of squares along
the last dimension is always less than 1.

```python
>>> unit_disk = distx.UnitDisk()
>>> u = unit_disk.sample(key, (5,))
>>> print(jnp.sum(u**2, axis=-1))
[0.07...]

```

#### NoncentralChi2

A [non-central chi-squared
distribution](https://en.wikipedia.org/wiki/Noncentral_chi-squared_distribution).
To use this distribution, you'll need to install the optional
`tensorflow-probability` dependency.

```python
>>> ncx2 = distx.NoncentralChi2(df=3, nc=2.)
>>> print(ncx2.sample(key, (5,)))
[2.19...]

```

#### MarginalizedLinear

The marginalized product of two (possibly multivariate) normal distributions
with a linear relationship between them. The mathematics of these models is
discussed in detail in [this note](https://arxiv.org/abs/2005.14199), and this
distribution implements it efficiently, assuming that the number of
marginalized parameters is small compared to the size of the dataset.

Here is a particularly simple example, a fully marginalized model for fitting a
line to data:

```python
>>> def model(x, y=None):
...     design_matrix = jnp.vander(x, 2)
...     prior = dist.Normal(0.0, 1.0)
...     data = dist.Normal(0.0, 2.0)
...     numpyro.sample(
...         "y",
...         distx.MarginalizedLinear(design_matrix, prior, data),
...         obs=y
...     )
...

```

Things get a little more interesting when the design matrix and/or the
distributions are functions of non-linear parameters. For example, if we want to
find the period of a sinusoidal signal while also fitting for some unknown excess
measurement uncertainty (often called "jitter"), we can use the following model:

```python
>>> def model(x, y_err, y=None):
...     period = numpyro.sample("period", dist.Uniform(1.0, 250.0))
...     ln_jitter = numpyro.sample("ln_jitter", dist.Normal(0.0, 2.0))
...     design_matrix = jnp.stack(
...         [
...             jnp.sin(2 * jnp.pi * x / period),
...             jnp.cos(2 * jnp.pi * x / period),
...             jnp.ones_like(x),
...         ],
...         axis=-1,
...     )
...     prior = dist.Normal(0.0, 10.0).expand([3])
...     data = dist.Normal(0.0, jnp.sqrt(y_err**2 + jnp.exp(2*ln_jitter)))
...     numpyro.sample(
...         "y",
...         distx.MarginalizedLinear(design_matrix, prior, data),
...         obs=y
...     )
...
>>> x = jnp.linspace(-1.0, 1.0, 5)
>>> samples = numpyro.infer.Predictive(model, num_samples=2)(key, x, 0.1)
>>> print(samples["period"])
[... ...]
>>> print(samples["y"])
[[... ... ...]
 [... ... ...]]

```

It's often useful to also track conditional samples of the marginalized
parameters during inference. The conditional distribution can be accessed using
the `conditional` method on `MarginalizedLinear`:

```python
>>> x = jnp.linspace(-1.0, 1.0, 5)
>>> y = jnp.sin(x)  # just some fake data
>>> design_matrix = jnp.vander(x, 2)
>>> prior = dist.Normal(0.0, 1.0)
>>> data = dist.Normal(0.0, 2.0)
>>> marg = distx.MarginalizedLinear(design_matrix, prior, data)
>>> cond = marg.conditional(y)
>>> print(type(cond).__name__)
MultivariateNormal
>>> print(cond.sample(key, (3,)))
[[...]
 [...]
 [...]]

```

### Optimization

The inference lore is a little mixed on the benefits of optimization as an
initialization tool for MCMC, but I find that, at least in many astronomy
applications, an initial optimization can make a huge difference in performance.
Even if you don't want to use the optimization results as an initialization, it
can still be useful to numerically search for the maximum _a posteriori_
parameters for your model. However, the NumPyro interface for this type of
optimization isn't terribly user-friendly, so this library provides some helpers
to make it more straightforward.

By default, this optimization uses the wrappers of SciPy's optimization routines
provided by the [JAXopt](https://github.com/google/jaxopt) library, so you'll
need to install JAXopt before running these examples:

```bash
python -m pip install jaxopt
```

The following example shows a simple optimization of a model with a single
parameter:

```python
>>> from numpyro_ext import optim as optimx
>>>
>>> def model(y=None):
...     x = numpyro.sample("x", dist.Normal(0.0, 1.0))
...     numpyro.sample("y", dist.Normal(x, 2.0), obs=y)
...
>>> soln = optimx.optimize(model)(key, y=0.5)

```

By default, the optimization starts from a prior sample, but you can provide
custom initial coordinates using the `start` argument:

```python
>>> soln = optimx.optimize(model, start={"x": 12.3})(key, y=0.5)

```

Similarly, if you only want to optimize a subset of the parameters, you can
pass a list of the sites to target using the `sites` argument:

```python
>>> soln = optimx.optimize(model, sites=["x"])(key, y=0.5)

```

### Information matrix computation

The Fisher information matrix for models with Gaussian likelihoods is
[straightforward to
compute](https://en.wikipedia.org/wiki/Fisher_information#Multivariate_normal_distribution),
and this library provides a helper function for automating this computation:

```python
>>> from numpyro_ext import information
>>>
>>> def model(x, y=None):
...     a = numpyro.sample("a", dist.Normal(0.0, 1.0))
...     b = numpyro.sample("b", dist.Normal(0.0, 1.0))
...     log_alpha = numpyro.sample("log_alpha", dist.Normal(0.0, 1.0))
...     cov = jnp.exp(log_alpha - 0.5 * (x[:, None] - x[None, :])**2)
...     cov += 0.1 * jnp.eye(len(x))
...     numpyro.sample(
...         "y",
...         dist.MultivariateNormal(loc=a * x + b, covariance_matrix=cov),
...         obs=y,
...     )
...
>>> x = jnp.linspace(-1.0, 1.0, 5)
>>> y = jnp.sin(x)  # the input data just needs to have the right shape
>>> params = {"a": 0.5, "b": -0.2, "log_alpha": -0.5}
>>> info = information(model)(params, x, y=y)
>>> print(info)
{'a': {'a': ..., 'b': ... 'log_alpha': ...}, 'b': ...}

```

The returned information matrix is a nested dictionary indexed by pairs of
parameter names, where each value is the corresponding block of the information
matrix.
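To make the structure of that result more concrete, here is a minimal NumPy sketch of the standard Fisher information formula for a Gaussian likelihood `y ~ N(mu(theta), C(theta))`: `I_ij = dmu_i^T C^-1 dmu_j + 0.5 * tr(C^-1 dC_i C^-1 dC_j)`. It is applied to the model from the example above, at the same parameter values, using derivatives worked out by hand. The helper `gaussian_fisher_information` is hypothetical and not part of `numpyro-ext`. The sketch only covers the likelihood term, and it is an assumption that `information` computes the same quantity (for example, it may also handle prior terms differently), so treat it as an illustration rather than a reference implementation.

```python
import numpy as np


def gaussian_fisher_information(dmu, dcov, cov):
    """Hypothetical helper: Fisher information for y ~ N(mu(theta), C(theta)).

    dmu:  list of arrays, dmu[i] = d mu / d theta_i, each of shape (N,)
    dcov: list of arrays, dcov[i] = d C / d theta_i, each of shape (N, N)
    cov:  the covariance matrix C, shape (N, N)
    """
    # An explicit inverse is fine for a small illustrative example
    cov_inv = np.linalg.inv(cov)
    n = len(dmu)
    info = np.zeros((n, n))
    for i in range(n):
        for j in range(n):
            # Mean term: dmu_i^T C^-1 dmu_j
            mean_term = dmu[i] @ cov_inv @ dmu[j]
            # Covariance term: 0.5 * tr(C^-1 dC_i C^-1 dC_j)
            cov_term = 0.5 * np.trace(cov_inv @ dcov[i] @ cov_inv @ dcov[j])
            info[i, j] = mean_term + cov_term
    return info


# The model from the example above: mu = a * x + b, C = K + 0.1 * I
x = np.linspace(-1.0, 1.0, 5)
a, b, log_alpha = 0.5, -0.2, -0.5
kernel = np.exp(log_alpha - 0.5 * (x[:, None] - x[None, :]) ** 2)
cov = kernel + 0.1 * np.eye(len(x))

# Hand-derived derivatives with respect to (a, b, log_alpha)
dmu = [x, np.ones_like(x), np.zeros_like(x)]  # mu does not depend on log_alpha
zeros = np.zeros_like(cov)
dcov = [zeros, zeros, kernel]  # dC/da = dC/db = 0, dC/dlog_alpha = K

info = gaussian_fisher_information(dmu, dcov, cov)
print(info.round(3))
```

The rows and columns here are ordered `(a, b, log_alpha)`. The `a`/`log_alpha` and `b`/`log_alpha` entries are exactly zero because the mean doesn't depend on `log_alpha` and the covariance doesn't depend on `a` or `b`. This block structure is what the nested-dictionary layout of `information` exposes by parameter name.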