{"id":22771850,"url":"https://github.com/stonet2000/jax-bandits","last_synced_at":"2025-06-27T23:02:51.670Z","repository":{"id":106468256,"uuid":"532106514","full_name":"StoneT2000/jax-bandits","owner":"StoneT2000","description":"bandit algorithms in jax","archived":false,"fork":false,"pushed_at":"2022-09-06T23:56:33.000Z","size":296,"stargazers_count":2,"open_issues_count":1,"forks_count":0,"subscribers_count":1,"default_branch":"main","last_synced_at":"2025-03-30T12:15:08.939Z","etag":null,"topics":[],"latest_commit_sha":null,"homepage":null,"language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"mit","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/StoneT2000.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null}},"created_at":"2022-09-02T23:33:23.000Z","updated_at":"2022-09-06T17:14:41.000Z","dependencies_parsed_at":null,"dependency_job_id":"f96a7526-c930-48b5-bd78-c88b9667baa6","html_url":"https://github.com/StoneT2000/jax-bandits","commit_stats":null,"previous_names":[],"tags_count":0,"template":false,"template_full_name":null,"purl":"pkg:github/StoneT2000/jax-bandits","repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/StoneT2000%2Fjax-bandits","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/StoneT2000%2Fjax-bandits/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/StoneT2000%2Fjax-bandits/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/StoneT2000%2Fjax-bandits/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/StoneT2000","download_url":"https://
codeload.github.com/StoneT2000/jax-bandits/tar.gz/refs/heads/main","sbom_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/StoneT2000%2Fjax-bandits/sbom","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":262347469,"owners_count":23296893,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":[],"created_at":"2024-12-11T16:17:33.691Z","updated_at":"2025-06-27T23:02:51.642Z","avatar_url":"https://github.com/StoneT2000.png","language":"Python","funding_links":[],"categories":[],"sub_categories":[],"readme":"# Jax Bandits\n\nA fast Jax-based library for multi-armed bandit problems.\n\nIncludes the following algorithms:\n- UCB1, UCB2\n- Thompson Sampling\n- Epsilon Greedy\n\nVia Jax and `vmap`, you can easily sample with an algorithm such as Epsilon Greedy 50 million times per second on a single GPU.\n\n## Installation\n\nThe package only depends on [jax](https://github.com/google/jax) and [flax](https://github.com/google/flax). 
Follow the instructions in those repositories to install them.\n\nTo install this package, run\n\n```\npip install --upgrade git+https://github.com/StoneT2000/jax-bandits.git\n```\n\n## Usage\n\nThis library provides a simple Jax-based environment interface for multi-armed bandits as well as algorithms.\n\nThe following shows how to initialize an environment and an algorithm.\n\n```python\nimport jax\nimport numpy as np\nfrom jaxbandits import BernoulliBandits, algos\n\n# Set the backend to CPU, which is usually faster because of the dispatch overhead on the GPU.\n# The GPU is useful if you plan to vmap the functions\njax.config.update('jax_platform_name', 'cpu')\n\nkey = jax.random.PRNGKey(0)\nkey, env_key = jax.random.split(key)\n\n# First initialize a bandit environment, e.g. Bernoulli Bandits, which comes with the environment state and functions\nenv = BernoulliBandits.create(env_key, arms=16)\n# Then initialize an algorithm, e.g. Thompson Sampling, which comes with the algo state and functions\nalgo = algos.ThompsonSampling.create(env.arms)\n```\n\nTo then start experimenting and solving, run\n\n```python\nN = 4096\nregrets = []\nfor i in range(N):\n    key, step_key = jax.random.split(key)\n    # perform one update step in the algorithm. Provide RNG, algorithm state, and the environment. 
\n    # Note that since all things jax are immutable, updated env and algo objects are returned as well\n    algo, env, action, reward = algo.update_step(step_key, env)\n    \n    # store the regret values\n    regret = env.regret(action)\n    regrets.append(regret)\ncumulative_regret = np.cumsum(np.array(regrets))\n```\n\nFor a packaged, jitted version of the above loop, you can use the `experiment` function in the package:\n\n```python\nfrom jaxbandits import experiment\nres = experiment(key, env, algo, N)\ncumulative_regret = np.cumsum(np.array(res[\"regret\"]))\nrewards = np.array(res[\"reward\"])\nactions = np.array(res[\"action\"])\n```\n\nThe above code can be found in [examples/experiment.py](https://github.com/StoneT2000/jax-bandits/blob/main/examples/experiment.py). Simply change the environment class and algorithm class to test them out.\n\nBecause of the high volume of small operations, the CPU backend is usually faster. The GPU backend is the better choice if you plan to `vmap`/`pmap` the code, which is possible because all of the algorithms and environments are registered as pytree nodes (via the `@flax.struct.dataclass` decorator).\n\nTo run a batch of experiments, simply `vmap` the `experiment` function. 
Example parallelization code is provided in [examples/parallel.py](https://github.com/StoneT2000/jax-bandits/blob/main/examples/parallel.py).\n\n## Algos\n\nThe following algorithms are available:\n\n```python\nfrom jaxbandits import algos\nalgos.ThompsonSampling\nalgos.UCB1\nalgos.UCB2\nalgos.EpsilonGreedy\n```\n\n## Example Results\n\nRun\n\n```\npython scripts/bench.py\n```\n\nto generate the following figure, showing a comparison of algorithms.\n\n![](assets/BernoulliBandits_results.png)","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fstonet2000%2Fjax-bandits","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fstonet2000%2Fjax-bandits","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fstonet2000%2Fjax-bandits/lists"}