{"id":18553694,"url":"https://github.com/daskol/yax","last_synced_at":"2025-07-08T11:07:03.112Z","repository":{"id":260002030,"uuid":"876315244","full_name":"daskol/yax","owner":"daskol","description":"Yet Another X: JAX/FLAX module tracing, modification, and evaluation.","archived":false,"fork":false,"pushed_at":"2024-12-09T18:48:01.000Z","size":280,"stargazers_count":0,"open_issues_count":1,"forks_count":0,"subscribers_count":1,"default_branch":"main","last_synced_at":"2025-05-15T11:50:35.835Z","etag":null,"topics":["flax","jax","mox","yax"],"latest_commit_sha":null,"homepage":"","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"apache-2.0","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/daskol.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":"CITATION.cff","codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null,"zenodo":null}},"created_at":"2024-10-21T19:00:21.000Z","updated_at":"2024-12-09T18:48:00.000Z","dependencies_parsed_at":"2025-04-12T15:50:30.342Z","dependency_job_id":null,"html_url":"https://github.com/daskol/yax","commit_stats":null,"previous_names":["daskol/yax"],"tags_count":1,"template":false,"template_full_name":null,"purl":"pkg:github/daskol/yax","repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/daskol%2Fyax","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/daskol%2Fyax/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/daskol%2Fyax/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/daskol%2Fyax/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/daskol","download_url":"https://codeload.g
ithub.com/daskol/yax/tar.gz/refs/heads/main","sbom_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/daskol%2Fyax/sbom","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":264257676,"owners_count":23580469,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["flax","jax","mox","yax"],"created_at":"2024-11-06T21:17:58.128Z","updated_at":"2025-07-08T11:07:03.070Z","avatar_url":"https://github.com/daskol.png","language":"Python","funding_links":[],"categories":[],"sub_categories":[],"readme":"![Linting and testing][1]\n![Nightly][2]\n\n[1]: https://github.com/daskol/yax/actions/workflows/on-push.yml/badge.svg\n[2]: https://github.com/daskol/yax/actions/workflows/on-schedule.yml/badge.svg\n\n# YAX\n\n*Yet Another X: JAX/FLAX module tracing, modification, and evaluation.*\n\n## Overview\n\nDeep learning frameworks like PyTorch, Keras, and JAX/FLAX usually provide a\n\"module-level\" API, which abstracts a layer—an architectural unit in a neural\nnetwork. While modules are descriptive and easy to use, they can sometimes be\ninconvenient to work with programmatically. Specifically, it is challenging to\nmodify model architecture on the fly, though changing weight structures\ndynamically is not as difficult. So, why can't we work with modules in the same\nflexible way?\n\nYAX is a library within the JAX/FLAX ecosystem for building, evaluating, and\nmodifying the intermediate representation of a neural network's modular\nstructure. 
Modular structures are represented with the help of MoX, a Module\neXpression, which is an extension of JAX expressions (Jaxpr). MoX is pronounced\nas *[mokh]* and means \"moss\" in Russian.\n\n```bash\npip install git+https://github.com/daskol/yax.git\n```\n\n## Usage\n\nModule expressions (MoX) are extremely useful in certain situations. For\nexample, they enable the application of custom LoRA-like adapters or model\nperformance optimizations, such as activation functions with quantized\ngradients (see [fewbit][4]). The examples below demonstrate what YAX/MoX can\naccomplish on the following ResBlock.\n\n```python\nimport flax.linen as nn\nimport jax\nimport jax.numpy as jnp\nimport yax\n\nclass ResBlock(nn.Module):\n    @nn.compact\n    def __call__(self, xs):\n        return xs + nn.Dense(10)(xs)\n\nmod = ResBlock()\nbatch = jnp.empty((1, 10))\nparams = jax.jit(mod.init)(jax.random.PRNGKey(42), batch)\n```\n\n**Tracing** First, we need to build a module representation, that is, a MoX.\nThis works similarly to creating a Jaxpr with `jax.make_jaxpr`.\n\n```python\nmox = yax.make_mox(mod.apply)(params, batch)\nprint(mox)\n```\n\nPretty printing for MoX is rough at the moment, but the output looks like the\nfollowing. MoX can also be serialized to XML and YSON (see the XML and YSON\nsections below).\n\n```jaxpr\n{ lambda ; a:f32[10] b:f32[10,10] c:f32[1,10]. let\n  d:f32[1,10] = module_call {\n    lambda ; a:f32[10] b:f32[10,10] c:f32[1,10]. 
let\n      d:f32[1,10] = dot_general[dimension_numbers=(([1], [0]), ([], []))] c b\n      e:f32[1,10] = reshape[dimensions=None new_sizes=(1, 10)] a\n      f:f32[1,10] = add d e\n    in (f,) }\n  e:f32[1,10] = add d c\n  in (e,) }\n```\n\n**Evaluation** MoX can be evaluated similarly to Jaxpr, but the most important\nfeature is that `yax.eval_mox` composes with common JAX transformations such\nas `jax.jit`.\n\n```python\ndef apply(params, batch):\n    return yax.eval_mox(mox, params, batch)\n\n_ = apply(params, batch)  # Eager evaluation.\n_ = jax.jit(apply)(params, batch)  # JIT-compiled execution.\n```\n\n**Querying** MoX provides tools for model exploration and examination.\nSpecifically, MoX can help answer questions like \"Which `nn.Dense` modules\nhave 10 features?\"\n\n```python\nfrom typing import Sequence\nmodules: Sequence[yax.Mox] = yax.query('//module_call[@features=10]', mox)\n```\n\nQueries are written in XPath, the XML Path Language. XPath is a concise way to\nexpress search conditions, and it fits naturally here: the module tree can be\nrepresented like a DOM tree, where nested elements model the nested structure\nof a neural network and element attributes hold module attributes such as\n`features`.\n\n**Modification** With such an expressive query language, modifying the original\nmodel on the fly becomes easy. For example, one can replace all ReLU activation\nfunctions with GELU or substitute all `nn.Dense` layers with LoRA adapters.\n\n```python\n# Replace ReLU with GELU. `inputs` is an example array shaped like the ReLU input.\ngelu_mox = yax.make_mox(nn.gelu)(inputs)\nmodified_mox = yax.sub('//pjit[@name=\"relu\"]', gelu_mox, mox)\n\n# Apply LoRA adapters to all fully connected layers. `lora` is assumed to be a\n# LoRA adapter module defined elsewhere.\nlora_mox = yax.make_mox(lora.apply)(params, inputs)\nmodified_mox = yax.sub('//module_call[@type=\"Dense\"]', lora_mox, mox)\n```\n\n[4]: https://proceedings.mlr.press/v202/novikov23a.html\n\n## Module Expression (MoX)\n\n### XML\n\nThe funniest part about MoX is that it can be serialized to XML. 
Hardly anyone\nuses XML nowadays outside the Java ecosystem and some legacy projects. However,\nXML is actually a good fit here: a module expression is a tree of nodes with\nattributes, which is exactly what XML elements describe.\n\n```xml\n\u003cmodule_call type=\"flax.nn.Dense\" name=\"Dense_0\" features=\"10\"\u003e\n  \u003cinput type=\"fp32[10]\"\u003ea\u003c/input\u003e\n  \u003cinput type=\"fp32[10,10]\"\u003eb\u003c/input\u003e\n  \u003cinput type=\"fp32[10]\"\u003ec\u003c/input\u003e\n  \u003cdot_general dimension_numbers=\"(([0], [0]), ([], []))\"\u003e\n    \u003cinput type=\"fp32[10,10]\"\u003eb\u003c/input\u003e\n    \u003cinput type=\"fp32[10]\"\u003ec\u003c/input\u003e\n    \u003coutput type=\"fp32[10]\"\u003ed\u003c/output\u003e\n  \u003c/dot_general\u003e\n  \u003cpjit\n    jaxpr=\"{ lambda ; a:f32[10] b:f32[10]. let c:f32[10] = add a b in (c,) }\"\u003e\n    \u003cinput type=\"fp32[10]\"\u003ed\u003c/input\u003e\n    \u003cinput type=\"fp32[10]\"\u003ea\u003c/input\u003e\n    \u003coutput type=\"fp32[10]\"\u003ee\u003c/output\u003e\n  \u003c/pjit\u003e\n  \u003coutput type=\"fp32[10]\"\u003ee\u003c/output\u003e\n\u003c/module_call\u003e\n```\n\n### YSON\n\n[YSON](https://ytsaurus.tech/docs/en/user-guide/storage/yson) stands for Yandex Serialization Object\nNotation. It is a JSON-like serialization format with a similarly compact\nnotation, but it is more expressive: any node can carry attributes. In terms of\nrepresentational expressiveness, YSON is comparable to XML.\n\n```yson\n\u003cprimitive=\"module_call\";\n type=\"flax.nn.Dense\"; name=\"Dense_0\"; features=10;\n inputs={a=\"fp32[10]\"; b=\"fp32[10,10]\"; c=\"fp32[10]\"};\n outputs={e=\"fp32[10]\"}\u003e[\n  \u003cprimitive=\"dot_general\";\n   dimension_numbers=\"[[[0], [0]], [[], []]]\";\n   inputs={c=\"fp32[10]\"; b=\"fp32[10,10]\"};\n   outputs={d=\"fp32[10]\"}\u003e#;\n  \u003cprimitive=\"pjit\";\n   inputs={d=\"fp32[10]\"; a=\"fp32[10]\"};\n   outputs={e=\"fp32[10]\"};\n   jaxpr=\"{ lambda ; a:f32[10] b:f32[10]. 
let c:f32[10] = add a b in (c,) }\";\n  \u003e#;\n]\n```\n\n### Limitations\n\nSubstitution requires preserving several invariants.\n\n- Inputs and outputs are reused.\n- New outputs are prohibited for now.\n- New inputs are propagated up to the root node. Jaxpr (leaf) and MoX\n  (internal node) handle this differently.\n\n  - \\[Jaxpr\\] New inputs are appended to all parents.\n  - \\[MoX\\] Internal nodes have two kinds of input parameters: plain inputs\n    and weight params. FLAX requires weight params to be the first input\n    parameter. Thus, the old subtree should be updated with the new one.\n\n  Updating input parameters requires updating `in_tree` as well. Similarly,\n  updating weight params requires updating `var_tree`. Note that the root node\n  handles inputs and params differently: its params are passed explicitly,\n  while for all internal expressions params are part of the closure context.\n  Naturally, any modification of `in_tree` or `var_tree` requires updating the\n  input symbols.\n\n  Note that all parents should be marked as ephemeral. Also, the inputs and\n  outputs of a replacement should be type-checked against its predecessors and\n  successors, respectively.\n\n  ```python\n  def substitute(parents, expr):\n    # `parents` is assumed to be ordered from the root down, so this walks from\n    # the innermost parent up to the root. `update_param_tree` is not shown.\n    for parent in reversed(parents):\n      update_param_tree(parent, expr)\n  ```\n- Compositionality with `jax.lax.scan`, `jax.vmap`, and `jax.pmap` is not\n  verified.\n- Pretty printing of module expressions is not fully supported yet.\n\n## Container\n\n```shell\ndocker pull ghcr.io/daskol/yax\n```\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fdaskol%2Fyax","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fdaskol%2Fyax","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fdaskol%2Fyax/lists"}