{"id":25472858,"url":"https://github.com/epignatelli/navix","last_synced_at":"2025-04-06T08:11:50.042Z","repository":{"id":174936595,"uuid":"653048276","full_name":"epignatelli/navix","owner":"epignatelli","description":"Accelerated minigrid environments with JAX","archived":false,"fork":false,"pushed_at":"2024-08-01T18:29:27.000Z","size":4755,"stargazers_count":132,"open_issues_count":12,"forks_count":17,"subscribers_count":3,"default_branch":"main","last_synced_at":"2025-03-30T07:07:55.695Z","etag":null,"topics":["deep-reinforcement-learning","deep-rl","environment","gridworld","gridworld-environment","minigrid","reinforcement-learning","rl"],"latest_commit_sha":null,"homepage":"","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"apache-2.0","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/epignatelli.png","metadata":{"files":{"readme":"README.md","changelog":"CHANGELOG.md","contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":"CITATION.cff","codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":"AUTHORS","dei":null,"publiccode":null,"codemeta":null}},"created_at":"2023-06-13T10:00:44.000Z","updated_at":"2025-03-03T10:50:20.000Z","dependencies_parsed_at":"2023-12-30T06:12:09.090Z","dependency_job_id":"4483c19a-6457-4b36-9107-9cc6dec764d0","html_url":"https://github.com/epignatelli/navix","commit_stats":{"total_commits":233,"total_committers":4,"mean_commits":58.25,"dds":0.01716738197424894,"last_synced_commit":"4b1d7217f19827168b9ae757e65e271a39c6b03d"},"previous_names":["epignatelli/navix"],"tags_count":64,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/epignatelli%2Fnavix","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/epignatelli%2Fnavix/tags","releases_url":"
https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/epignatelli%2Fnavix/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/epignatelli%2Fnavix/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/epignatelli","download_url":"https://codeload.github.com/epignatelli/navix/tar.gz/refs/heads/main","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":247451654,"owners_count":20940939,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["deep-reinforcement-learning","deep-rl","environment","gridworld","gridworld-environment","minigrid","reinforcement-learning","rl"],"created_at":"2025-02-18T10:02:02.494Z","updated_at":"2025-04-06T08:11:50.021Z","avatar_url":"https://github.com/epignatelli.png","language":"Python","funding_links":[],"categories":["Recently Updated","Libraries"],"sub_categories":["[Feb 16, 2025](/content/2025/02/16/README.md)","New Libraries"],"readme":"\u003cdiv align=\"center\"\u003e\n\u003cimg width=150px src=\"https://github.com/epignatelli/navix/assets/26899347/4168c100-f0e6-4bae-9680-2c1a82bba8a4\" alt=\"logo\"\u003e\u003c/img\u003e\n\n# NAVIX: minigrid in JAX\n[![CI](https://github.com/epignatelli/navix/actions/workflows/CI.yml/badge.svg)](https://github.com/epignatelli/navix/actions/workflows/CI.yml)\n[![CD](https://github.com/epignatelli/navix/actions/workflows/CD.yml/badge.svg)](https://github.com/epignatelli/navix/actions/workflows/CD.yml)\n![PyPI 
version](https://img.shields.io/pypi/v/navix?label=PyPI\u0026color=%230099ab)\n[![arXiv](https://img.shields.io/badge/arXiv-2407.19396-b31b1b.svg?style=flat)](https://arxiv.org/abs/2407.19396)\n\n**[Quickstart](#what-is-navix)** | **[Install](#installation)** | **[Performance](#performance)** | **[Examples](#examples)** | **[Docs](https://epignatelli.com/navix)** | **[The JAX ecosystem](#jax-ecosystem-for-rl)** | **[Contribute](#join-us)** | **[Cite](#cite-us-please)**\n\n\u003c/div\u003e\n\n## What is NAVIX?\nNAVIX is a JAX-powered reimplementation of [MiniGrid](https://github.com/Farama-Foundation/Minigrid). Experiments that took \u003cins\u003e**1 week**\u003c/ins\u003e now take \u003cins\u003e**15 minutes**\u003c/ins\u003e.\n\nA 200 000x speedup over MiniGrid and 670 million steps/s are not just a speed improvement. They open a whole new paradigm, giving access to experiments that were previously impossible, e.g., those that would take years to run.\n\nIt changes the game. See the [performance](#performance) section for details, and the [documentation](https://epignatelli.com/navix) for more information.\n\nKey features:\n- Performance boost: NAVIX offers a speed increase of \u003cins\u003e**over 1000x**\u003c/ins\u003e compared to the original MiniGrid implementation, enabling faster experimentation and scaling. You can see a preliminary performance comparison [here](docs/performance.py), and a full benchmark [here](benchmarks/).\n- XLA compilation: NAVIX leverages XLA to optimize its computations for many accelerators, and runs on CPU, GPU, and TPU.\n- Autograd support: differentiate through environment transitions, opening up new possibilities such as learned world models.\n- Batched hyperparameter tuning: run thousands of experiments in parallel, enabling hyperparameter tuning at scale. 
Quickly find out whether your algorithm fails because of its hyperparameter choice.\n- Finally focus on research into methods, not on engineering.\n\nThe library is in active development, and we are working on adding more environments and features.\nIf you want to join the development and contribute, please [open a discussion](https://github.com/epignatelli/navix/discussions/new?category=general) and let's have a chat!\n\n## Installation\n#### Install JAX\nFollow the official installation guide for your OS and preferred accelerator: https://github.com/google/jax#installation.\n\n#### Install NAVIX\n```bash\npip install navix\n```\n\nOr, for the latest version from source:\n```bash\npip install git+https://github.com/epignatelli/navix\n```\n\n## Performance\nNAVIX improves on MiniGrid in both execution speed *and* throughput, making it possible to run more than 2048 PPO agents in parallel almost 10 times faster than *a single* PPO agent in the original MiniGrid.\n\n![speedup_env](https://github.com/user-attachments/assets/b221048c-1b98-43d8-b09b-2a240412dd81)\n\nIn batch mode with 2048 parallel agents, NAVIX performs 668 734 693.88 steps per second (~670 million steps/s), while the original MiniGrid implementation performs 1M steps in 318.01 s, i.e., about 3 145 steps per second. 
This is a speedup of over 200 000×.\n\n![throughput_ppo](https://github.com/user-attachments/assets/eea6e312-55b4-41c3-adb0-4207c5e78fd1)\n\n## Examples\nYou can view a full set of examples [here](examples/) (more coming); below are the most common use cases.\n\n### Compiling a collection step\n```python\nimport jax\nimport jax.numpy as jnp\nimport navix as nx\n\nN_TIMESTEPS = 1000  # number of environment steps per run (illustrative value)\n\n\ndef run(seed):\n  env = nx.make('MiniGrid-Empty-8x8-v0')  # Create the environment\n  key = jax.random.PRNGKey(seed)\n  reset_key, action_key = jax.random.split(key)\n  timestep = env.reset(reset_key)\n  # Pre-sample one random action per step\n  actions = jax.random.randint(action_key, (N_TIMESTEPS,), 0, env.action_space.n)\n\n  def body_fun(timestep, action):\n      timestep = env.step(timestep, action)  # Update the environment state\n      return timestep, ()\n\n  return jax.lax.scan(body_fun, timestep, actions)[0]\n\n# Compile the whole batched rollout (1000 seeds in parallel) for maximum performance\nfinal_timestep = jax.jit(jax.vmap(run))(jnp.arange(1000))\n```\n\n### Compiling a full training run\n```python\nimport jax\nimport navix as nx\n\nN_STEPS = 100  # rollout length per episode (illustrative value)\nenv = nx.make('MiniGrid-MultiRoom-N2-S4-v0')\n\n\n# Hypothetical policy function: replace with your own network.\n# This placeholder ignores params and observation and samples a random action.\ndef policy(params, observation, key):\n    return jax.random.randint(key, (), 0, env.action_space.n)\n\n\ndef run_episode(params, key):\n    # Simulates a fixed-length rollout and returns its total reward.\n    # jax.lax.scan replaces a Python while loop, which cannot run under jit.\n    # Episode boundaries inside the rollout are not tracked here.\n    reset_key, step_key = jax.random.split(key)\n    timestep = env.reset(reset_key)\n\n    def body_fun(timestep, k):\n        action = policy(params, timestep.observation, k)\n        timestep = env.step(timestep, action)\n        return timestep, timestep.reward\n\n    step_keys = jax.random.split(step_key, N_STEPS)\n    _, rewards = jax.lax.scan(body_fun, timestep, step_keys)\n    return rewards.sum()\n\n\ndef train_policy(params, num_episodes, num_iterations):\n    # Compile the batched rollout once with XLA: params are shared, keys are batched\n    compiled_train = jax.jit(jax.vmap(run_episode, in_axes=(None, 0)))\n    key = jax.random.PRNGKey(0)\n\n    for _ in range(num_iterations):\n        key, subkey = jax.random.split(key)\n        rewards = compiled_train(params, jax.random.split(subkey, num_episodes))\n        # ... Update the policy parameters based on rewards ...\n    return params\n\n\n# Start the training (params=None stands in for real policy parameters)\ntrain_policy(params=None, num_episodes=100, num_iterations=10)\n```\n\n### Backpropagation through the environment\n```python\nimport jax\nimport jax.numpy as jnp\nimport flax.linen as nn\nimport navix as nx\n\n\nclass Model(nn.Module):\n  @nn.compact\n  def __call__(self, x):\n    # ... your NN here; this placeholder is a single dense layer predicting the next observation\n    x = x.astype(jnp.float32)\n    y = nn.Dense(x.size)(x.reshape(-1))\n    return y.reshape(x.shape)\n\nmodel = Model()\nenv = nx.make('MiniGrid-Empty-8x8-v0')\n\ndef loss(params, timestep):\n  action = jnp.asarray(0)\n  pred_obs = model.apply(params, timestep.observation)\n  timestep = env.step(timestep, action)  # the transition is traced inside the loss\n  return jnp.square(timestep.observation.astype(jnp.float32) - pred_obs).mean()\n\nkey = jax.random.PRNGKey(0)\ntimestep = env.reset(key)\nparams = model.init(key, timestep.observation)\n\ngradients = jax.grad(loss)(params, timestep)\n```\n\n## JAX ecosystem for RL\nNAVIX is part of a wider ecosystem of JAX-powered modules for RL. 
Check out the following projects:\n- Environments:\n  - [Gymnax](https://github.com/RobertTLange/gymnax): a broad range of RL environments\n  - [Brax](https://github.com/google/brax): a physics engine for robotics experiments\n  - [EnvPool](https://github.com/sail-sg/envpool): a set of various batched environments\n  - [Craftax](https://github.com/MichaelTMatthews/Craftax): a JAX reimplementation of the game of [Crafter](https://github.com/danijar/crafter)\n  - [Jumanji](https://github.com/instadeepai/jumanji): another set of diverse environments\n  - [PGX](https://github.com/sotetsuk/pgx): board games commonly used for RL, such as backgammon, chess, shogi, and go\n  - [JAX-MARL](https://github.com/FLAIROx/JaxMARL): multi-agent RL environments in JAX\n  - [Xland-Minigrid](https://github.com/corl-team/xland-minigrid/): a set of JAX-reimplemented grid-world environments\n  - [Minimax](https://github.com/facebookresearch/minimax): a JAX library for RL autocurricula with 120x faster baselines\n- Agents:\n  - [PureJaxRL](https://github.com/luchris429/purejaxrl): fully jitted training routines\n  - [Rejax](https://github.com/keraJLi/rejax): a suite of diverse agents, including DDPG, DQN, PPO, SAC, and TD3\n  - [Stoix](https://github.com/EdanToledo/Stoix): useful implementations of popular single-agent RL algorithms in JAX\n  - [JAX-CORL](https://github.com/nissymori/JAX-CORL): lean single-file implementations of offline RL algorithms with solid performance reports\n  - [Dopamine](https://github.com/google/dopamine): a research framework for fast prototyping of reinforcement learning algorithms\n\n## Join Us!\n\nNAVIX is actively developed. If you'd like to contribute to this open-source project, we welcome your involvement! 
Start a discussion or open a pull request.\n\nPlease consider starring the project if you like NAVIX!\n\n## Cite us, please!\nIf you use NAVIX, please cite it as:\n\n```bibtex\n@article{pignatelli2024navix,\n  title={NAVIX: Scaling MiniGrid Environments with JAX},\n  author={Pignatelli, Eduardo and Liesen, Jarek and Lange, Robert Tjarko and Lu, Chris and Castro, Pablo Samuel and Toni, Laura},\n  journal={arXiv preprint arXiv:2407.19396},\n  year={2024}\n}\n```\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fepignatelli%2Fnavix","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fepignatelli%2Fnavix","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fepignatelli%2Fnavix/lists"}