{"id":17059263,"url":"https://github.com/danijar/ninjax","last_synced_at":"2025-04-06T01:07:20.862Z","repository":{"id":56682080,"uuid":"495255860","full_name":"danijar/ninjax","owner":"danijar","description":"General Modules for JAX","archived":false,"fork":false,"pushed_at":"2025-02-26T22:50:49.000Z","size":102,"stargazers_count":64,"open_issues_count":3,"forks_count":2,"subscribers_count":4,"default_branch":"main","last_synced_at":"2025-03-30T00:11:06.736Z","etag":null,"topics":["deep-learning","jax"],"latest_commit_sha":null,"homepage":"https://ninjax.readthedocs.io","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"mit","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/danijar.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null}},"created_at":"2022-05-23T04:20:12.000Z","updated_at":"2025-02-26T22:50:53.000Z","dependencies_parsed_at":"2024-02-05T02:37:55.145Z","dependency_job_id":"a3364c2a-630a-4844-a451-e682d7af9630","html_url":"https://github.com/danijar/ninjax","commit_stats":{"total_commits":81,"total_committers":1,"mean_commits":81.0,"dds":0.0,"last_synced_commit":"8fb2dc8047e140b36fa0d4a78390083944e52e71"},"previous_names":[],"tags_count":0,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/danijar%2Fninjax","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/danijar%2Fninjax/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/danijar%2Fninjax/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/danijar%2Fninjax/manifests","owner_
url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/danijar","download_url":"https://codeload.github.com/danijar/ninjax/tar.gz/refs/heads/main","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":247419860,"owners_count":20936012,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["deep-learning","jax"],"created_at":"2024-10-14T10:33:41.137Z","updated_at":"2025-04-06T01:07:20.842Z","avatar_url":"https://github.com/danijar.png","language":"Python","funding_links":[],"categories":[],"sub_categories":[],"readme":"[![PyPI](https://img.shields.io/pypi/v/ninjax.svg)](https://pypi.python.org/pypi/ninjax/#history)\n\n# 🥷  Ninjax: Flexible Modules for JAX\n\nNinjax is a general and practical module system for [JAX][jax]. It gives users\nfull and transparent control over updating the state of each module, bringing\nflexibility to JAX and enabling new use cases.\n\n## Overview\n\nNinjax provides a simple and general `nj.Module` class.\n\n- Modules can store state for things like model parameters, Adam momentum\n  buffer, BatchNorm statistics, recurrent state, etc.\n- Modules can read and write their state entries. For example, this allows\n  modules to have train methods, because they can update their parameters from\n  the inside.\n- Any method can initialize, read, and write state entries. 
This avoids the\n  need for a special `build()` method or `@compact` decorator used in Flax.\n- Ninjax makes it easy to mix and match modules from different libraries, such\n  as [Flax][flax] and [Haiku][haiku].\n- Instead of PyTrees, Ninjax state is a flat `dict` that maps\n  string keys like `/net/layer1/weights` to `jnp.array`s. This makes it easy\n  to iterate over, modify, and save or load state.\n- Modules can specify typed hyperparameters using the [dataclass][dataclass]\n  syntax.\n\n[jax]: https://github.com/google/jax\n[flax]: https://github.com/google/flax\n[haiku]: https://github.com/deepmind/dm-haiku\n[dataclass]: https://docs.python.org/3/library/dataclasses.html\n\n## Installation\n\nNinjax is [a single file][file], so you can just copy it to your project\ndirectory. Or you can install the package:\n\n```\npip install ninjax\n```\n\n[file]: https://github.com/danijar/ninjax/blob/main/ninjax/ninjax.py\n\n## Quickstart\n\n```python3\nimport flax\nimport jax\nimport jax.numpy as jnp\nimport ninjax as nj\nimport optax\n\nLinear = nj.FromFlax(flax.linen.Dense)\n\n\nclass MyModel(nj.Module):\n\n  lr: float = 1e-3\n\n  def __init__(self, size):\n    self.size = size\n    # Define submodules upfront\n    self.h1 = Linear(128, name='h1')\n    self.h2 = Linear(128, name='h2')\n    self.opt = optax.adam(self.lr)\n\n  def predict(self, x):\n    x = jax.nn.relu(self.h1(x))\n    x = jax.nn.relu(self.h2(x))\n    # Define submodules inline\n    x = self.sub('h3', Linear, self.size, use_bias=False)(x)\n    # Create state entries inline\n    x += self.value('bias', jnp.zeros, self.size)\n    # Update state entries inline\n    self.write('bias', self.read('bias') + 0.1)\n    return x\n\n  def loss(self, x, y):\n    return ((self.predict(x) - y) ** 2).mean()\n\n  def train(self, x, y):\n    # Take grads wrt. 
to submodules or state keys\n    wrt = [self.h1, self.h2, f'{self.path}/h3', f'{self.path}/bias']\n    loss, params, grads = nj.grad(self.loss, wrt)(x, y)\n    # Update weights\n    state = self.sub('optstate', nj.Tree, self.opt.init, params)\n    updates, new_state = self.opt.update(grads, state.read(), params)\n    params = optax.apply_updates(params, updates)\n    nj.context().update(params)  # Store the new params\n    state.write(new_state)       # Store new optimizer state\n    return loss\n\n\n# Create model and example data\nmodel = MyModel(3, name='model')\nx = jnp.ones((64, 32), jnp.float32)\ny = jnp.ones((64, 3), jnp.float32)\n\n# Populate initial state from one or more functions\nstate = {}\nstate = nj.init(model.train)(state, x, y, seed=0)\nprint(state['model/bias'])\n\n# Purify for JAX transformations\ntrain = jax.jit(nj.pure(model.train))\n\n# Training loop\nfor x, y in [(x, y)] * 10:\n  state, loss = train(state, x, y)\n  print('Loss:', float(loss))\n\n# Look at the parameters\nprint(state['model/bias'])\n```\n\n## Questions\n\nIf you have a question, please [file an issue][issues].\n\n[issues]: https://github.com/danijar/ninjax/issues\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fdanijar%2Fninjax","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fdanijar%2Fninjax","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fdanijar%2Fninjax/lists"}