https://github.com/dfm/emcee-jax
An experiment: emcee implemented in JAX
- Host: GitHub
- URL: https://github.com/dfm/emcee-jax
- Owner: dfm
- License: MIT
- Created: 2022-06-10T00:14:46.000Z
- Default Branch: main
- Last Pushed: 2022-07-01T15:27:55.000Z
- Last Synced: 2025-03-27T10:37:18.584Z
- Language: Python
- Size: 81.1 KB
- Stars: 25
- Watchers: 3
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
- Changelog: HISTORY.rst
- License: LICENSE
- Code of conduct: CODE_OF_CONDUCT.md
README
# emcee-jax
An experiment: the emcee ensemble sampler implemented in JAX.
A simple example:
```python
>>> import jax
>>> import emcee_jax
>>>
>>> def log_prob(theta, a1=100.0, a2=20.0):
... x1, x2 = theta
... return -(a1 * (x2 - x1**2)**2 + (1 - x1)**2) / a2
...
>>> num_walkers, num_steps = 100, 1000
>>> key1, key2, key3 = jax.random.split(jax.random.PRNGKey(0), 3)
>>> coords = jax.random.normal(key1, shape=(num_walkers, 2))
>>> sampler = emcee_jax.EnsembleSampler(log_prob)
>>> state = sampler.init(key2, coords)
>>> trace = sampler.sample(key3, state, num_steps)
```
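The `log_prob` here is a scaled Rosenbrock-like log density, which peaks at `(1, 1)`. The snippet below is a standalone sketch (using only standard JAX, not the emcee-jax API) showing how the same function can be evaluated for all walkers at once with `jax.vmap`:

```python
import jax
import jax.numpy as jnp

def log_prob(theta, a1=100.0, a2=20.0):
    x1, x2 = theta
    return -(a1 * (x2 - x1**2)**2 + (1 - x1)**2) / a2

# Both penalty terms vanish at (1, 1), so the log density peaks at zero there.
peak = log_prob(jnp.array([1.0, 1.0]))

# vmap vectorizes log_prob over the leading (walker) axis of the coordinates.
coords = jax.random.normal(jax.random.PRNGKey(0), shape=(100, 2))
log_probs = jax.vmap(log_prob)(coords)
```

Because every walker's log density is computed in one vectorized call, this is the shape of evaluation an ensemble sampler can exploit on accelerators.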
An example using PyTrees as input coordinates:
```python
>>> import jax
>>> import emcee_jax
>>>
>>> def log_prob(theta, a1=100.0, a2=20.0):
... x1, x2 = theta["x"], theta["y"]
... return -(a1 * (x2 - x1**2)**2 + (1 - x1)**2) / a2
...
>>> num_walkers, num_steps = 100, 1000
>>> key1, key2, key3, key4 = jax.random.split(jax.random.PRNGKey(0), 4)
>>> coords = {
... "x": jax.random.normal(key1, shape=(num_walkers,)),
... "y": jax.random.normal(key2, shape=(num_walkers,)),
... }
>>> sampler = emcee_jax.EnsembleSampler(log_prob)
>>> state = sampler.init(key3, coords)
>>> trace = sampler.sample(key4, state, num_steps)
```
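JAX's PyTree machinery is what makes dict-valued coordinates like these work transparently. As a standalone sketch (standard JAX only, not part of the emcee-jax API), `jax.tree_util.tree_map` can inspect every leaf of the coordinate dict, and `jax.vmap` maps over PyTree inputs leaf by leaf:

```python
import jax
import jax.numpy as jnp

num_walkers = 100
key1, key2 = jax.random.split(jax.random.PRNGKey(0))
coords = {
    "x": jax.random.normal(key1, shape=(num_walkers,)),
    "y": jax.random.normal(key2, shape=(num_walkers,)),
}

# tree_map applies a function to every leaf array in the PyTree.
shapes = jax.tree_util.tree_map(jnp.shape, coords)

def log_prob(theta, a1=100.0, a2=20.0):
    x1, x2 = theta["x"], theta["y"]
    return -(a1 * (x2 - x1**2)**2 + (1 - x1)**2) / a2

# vmap maps over the leading axis of each leaf, so the dict structure is
# preserved while the walker axis is vectorized away.
log_probs = jax.vmap(log_prob)(coords)
```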
An example that includes deterministics:
```python
>>> import jax
>>> import emcee_jax
>>>
>>> def log_prob(theta, a1=100.0, a2=20.0):
... x1, x2 = theta
... some_number = x1 + jax.numpy.sin(x2)
... log_prob_value = -(a1 * (x2 - x1**2)**2 + (1 - x1)**2) / a2
...
... # This second argument can be any PyTree
... return log_prob_value, {"some_number": some_number}
...
>>> num_walkers, num_steps = 100, 1000
>>> key1, key2, key3 = jax.random.split(jax.random.PRNGKey(0), 3)
>>> coords = jax.random.normal(key1, shape=(num_walkers, 2))
>>> sampler = emcee_jax.EnsembleSampler(log_prob)
>>> state = sampler.init(key2, coords)
>>> trace = sampler.sample(key3, state, num_steps)
```
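Returning a `(log_prob_value, aux_pytree)` pair mirrors JAX's own `has_aux` convention, so the same function composes with standard JAX transforms outside the sampler. A minimal sketch (plain JAX, not emcee-jax API):

```python
import jax
import jax.numpy as jnp

def log_prob(theta, a1=100.0, a2=20.0):
    x1, x2 = theta
    some_number = x1 + jnp.sin(x2)
    log_prob_value = -(a1 * (x2 - x1**2)**2 + (1 - x1)**2) / a2
    return log_prob_value, {"some_number": some_number}

# has_aux=True tells jax.grad that the second return value is auxiliary
# data: gradients flow only through the scalar first return value.
grad, aux = jax.grad(log_prob, has_aux=True)(jnp.array([1.0, 1.0]))
```

At the mode `(1, 1)` the gradient is zero, while the deterministic quantity `1 + sin(1)` is still carried along in `aux`.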
You can even use pure-Python log probability functions:
```python
>>> import jax
>>> import numpy as np
>>> import emcee_jax
>>> from emcee_jax.host_callback import wrap_python_log_prob_fn
>>>
>>> # A log prob function that uses numpy (not jax.numpy) internally
>>> @wrap_python_log_prob_fn
... def log_prob(theta, a1=100.0, a2=20.0):
... x1, x2 = theta
... return -(a1 * np.square(x2 - x1**2) + np.square(1 - x1)) / a2
...
>>> num_walkers, num_steps = 100, 1000
>>> key1, key2, key3 = jax.random.split(jax.random.PRNGKey(0), 3)
>>> coords = jax.random.normal(key1, shape=(num_walkers, 2))
>>> sampler = emcee_jax.EnsembleSampler(log_prob)
>>> state = sampler.init(key2, coords)
>>> trace = sampler.sample(key3, state, num_steps)
```
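The internals of `wrap_python_log_prob_fn` aren't shown here, but the general mechanism JAX provides for calling host-side NumPy code from traced computations is `jax.pure_callback`. The sketch below illustrates that idea directly; the wrapper name `wrapped_log_prob` and the float32 cast are illustrative choices, not the library's actual implementation:

```python
import jax
import jax.numpy as jnp
import numpy as np

def numpy_log_prob(theta, a1=100.0, a2=20.0):
    # Pure NumPy: this runs on the host, outside JAX's tracing machinery.
    x1, x2 = theta
    return -(a1 * np.square(x2 - x1**2) + np.square(1 - x1)) / a2

def wrapped_log_prob(theta):
    # pure_callback ships `theta` to the host, runs the NumPy function,
    # and returns a value with the promised shape and dtype.
    result_shape = jax.ShapeDtypeStruct((), np.float32)
    return jax.pure_callback(
        lambda t: np.asarray(numpy_log_prob(t), dtype=np.float32),
        result_shape,
        theta,
    )

value = wrapped_log_prob(jnp.array([1.0, 1.0]))
```

The trade-off is that host callbacks block accelerator execution on a round trip to Python, so this path is mainly useful for legacy or non-differentiable log probability functions.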