# emcee-jax

An experiment: emcee implemented in JAX.

Source repository: https://github.com/dfm/emcee-jax (MIT license).

A simple example, sampling a two-dimensional Rosenbrock-style density with an ensemble of 100 walkers:

```python
>>> import jax
>>> import emcee_jax
>>>
>>> def log_prob(theta, a1=100.0, a2=20.0):
...     x1, x2 = theta
...     return -(a1 * (x2 - x1**2)**2 + (1 - x1)**2) / a2
...
>>> num_walkers, num_steps = 100, 1000
>>> key1, key2, key3 = jax.random.split(jax.random.PRNGKey(0), 3)
>>> coords = jax.random.normal(key1, shape=(num_walkers, 2))
>>> sampler = emcee_jax.EnsembleSampler(log_prob)
>>> state = sampler.init(key2, coords)
>>> trace = sampler.sample(key3, state, num_steps)

```

The coordinates can also be a PyTree, here a dict in which each leaf has one entry per walker:

```python
>>> import jax
>>> import emcee_jax
>>>
>>> def log_prob(theta, a1=100.0, a2=20.0):
...     x1, x2 = theta["x"], theta["y"]
...     return -(a1 * (x2 - x1**2)**2 + (1 - x1)**2) / a2
...
>>> num_walkers, num_steps = 100, 1000
>>> key1, key2, key3, key4 = jax.random.split(jax.random.PRNGKey(0), 4)
>>> coords = {
...     "x": jax.random.normal(key1, shape=(num_walkers,)),
...     "y": jax.random.normal(key2, shape=(num_walkers,)),
... }
>>> sampler = emcee_jax.EnsembleSampler(log_prob)
>>> state = sampler.init(key3, coords)
>>> trace = sampler.sample(key4, state, num_steps)

```

The log probability function can also return deterministics (extra quantities computed alongside the log probability) as a second output:

```python
>>> import jax
>>> import emcee_jax
>>>
>>> def log_prob(theta, a1=100.0, a2=20.0):
...     x1, x2 = theta
...     some_number = x1 + jax.numpy.sin(x2)
...     log_prob_value = -(a1 * (x2 - x1**2)**2 + (1 - x1)**2) / a2
...
...     # The second return value can be any PyTree
...     return log_prob_value, {"some_number": some_number}
...
>>> num_walkers, num_steps = 100, 1000
>>> key1, key2, key3 = jax.random.split(jax.random.PRNGKey(0), 3)
>>> coords = jax.random.normal(key1, shape=(num_walkers, 2))
>>> sampler = emcee_jax.EnsembleSampler(log_prob)
>>> state = sampler.init(key2, coords)
>>> trace = sampler.sample(key3, state, num_steps)

```

You can even use pure-Python log probability functions by wrapping them with `wrap_python_log_prob_fn`:

```python
>>> import jax
>>> import numpy as np
>>> import emcee_jax
>>> from emcee_jax.host_callback import wrap_python_log_prob_fn
>>>
>>> # A log prob function that uses NumPy rather than jax.numpy internally
>>> @wrap_python_log_prob_fn
... def log_prob(theta, a1=100.0, a2=20.0):
...     x1, x2 = theta
...     return -(a1 * np.square(x2 - x1**2) + np.square(1 - x1)) / a2
...
>>> num_walkers, num_steps = 100, 1000
>>> key1, key2, key3 = jax.random.split(jax.random.PRNGKey(0), 3)
>>> coords = jax.random.normal(key1, shape=(num_walkers, 2))
>>> sampler = emcee_jax.EnsembleSampler(log_prob)
>>> state = sampler.init(key2, coords)
>>> trace = sampler.sample(key3, state, num_steps)

```
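To illustrate what a single ensemble-sampler step involves, here is a minimal sketch of the Goodman & Weare affine-invariant "stretch" move (the default move in the original emcee) written in plain JAX. It is not emcee_jax's internal code, and the README does not say which move `EnsembleSampler` uses by default. The helpers `stretch_half` and `make_sampler`, the stretch scale `a=2.0`, and the update of the ensemble in two halves are assumptions chosen for illustration.

```python
import functools

import jax
import jax.numpy as jnp


def log_prob(theta, a1=100.0, a2=20.0):
    # The same Rosenbrock-style density used in the examples above.
    x1, x2 = theta
    return -(a1 * (x2 - x1**2) ** 2 + (1 - x1) ** 2) / a2


def stretch_half(key, active, complement, lp_active, log_prob_fn, a):
    """Hypothetical helper: stretch-move the `active` walkers using partners from `complement`."""
    n_active, ndim = active.shape
    key_z, key_j, key_u = jax.random.split(key, 3)
    # Draw z from g(z) proportional to 1/sqrt(z) on [1/a, a] by inverting its CDF.
    z = ((a - 1.0) * jax.random.uniform(key_z, (n_active,)) + 1.0) ** 2 / a
    # Each active walker picks one random partner from the other half.
    j = jax.random.randint(key_j, (n_active,), 0, complement.shape[0])
    partner = complement[j]
    proposal = partner + z[:, None] * (active - partner)
    lp_prop = jax.vmap(log_prob_fn)(proposal)
    # Metropolis-Hastings acceptance, including the z^(ndim - 1) factor.
    log_accept = (ndim - 1) * jnp.log(z) + lp_prop - lp_active
    accept = jnp.log(jax.random.uniform(key_u, (n_active,))) < log_accept
    new_coords = jnp.where(accept[:, None], proposal, active)
    new_lp = jnp.where(accept, lp_prop, lp_active)
    return new_coords, new_lp, accept


def make_sampler(log_prob_fn, a=2.0):
    """Hypothetical helper: returns a jitted function that runs the whole chain."""

    def step(carry, key):
        coords, lp = carry
        half = coords.shape[0] // 2
        k1, k2 = jax.random.split(key)
        # Update the first half against the second, then the second half
        # against the freshly updated first half.
        first, lp1, acc1 = stretch_half(k1, coords[:half], coords[half:], lp[:half], log_prob_fn, a)
        second, lp2, acc2 = stretch_half(k2, coords[half:], first, lp[half:], log_prob_fn, a)
        coords = jnp.concatenate([first, second])
        lp = jnp.concatenate([lp1, lp2])
        return (coords, lp), (coords, jnp.concatenate([acc1, acc2]))

    @functools.partial(jax.jit, static_argnums=2)
    def sample(key, coords, num_steps):
        lp = jax.vmap(log_prob_fn)(coords)
        keys = jax.random.split(key, num_steps)
        # lax.scan runs every step inside one compiled loop.
        _, (chain, accepted) = jax.lax.scan(step, (coords, lp), keys)
        return chain, accepted

    return sample


key1, key2 = jax.random.split(jax.random.PRNGKey(0))
coords = jax.random.normal(key1, shape=(100, 2))
sample = make_sampler(log_prob)
chain, accepted = sample(key2, coords, 1000)
print(chain.shape)  # (1000, 100, 2)
print(accepted.mean())  # fraction of accepted proposals
```

Splitting the ensemble in two lets all walkers in one half be proposed in parallel (here with `jax.vmap`) while each proposal still depends only on walkers that are not being updated at the same time, which is why the second half is moved against the already-updated first half.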