{"id":15643032,"url":"https://github.com/brentyi/jax_dataclasses","last_synced_at":"2025-05-09T00:27:35.926Z","repository":{"id":39756431,"uuid":"372188516","full_name":"brentyi/jax_dataclasses","owner":"brentyi","description":"Pytrees + dataclasses ❤️","archived":false,"fork":false,"pushed_at":"2025-04-24T23:40:06.000Z","size":66,"stargazers_count":62,"open_issues_count":3,"forks_count":6,"subscribers_count":4,"default_branch":"main","last_synced_at":"2025-04-25T00:31:11.556Z","etag":null,"topics":["dataclasses","jax","python"],"latest_commit_sha":null,"homepage":"","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"mit","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/brentyi.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null}},"created_at":"2021-05-30T10:47:43.000Z","updated_at":"2025-04-24T23:37:46.000Z","dependencies_parsed_at":"2024-10-23T00:43:10.808Z","dependency_job_id":null,"html_url":"https://github.com/brentyi/jax_dataclasses","commit_stats":{"total_commits":46,"total_committers":2,"mean_commits":23.0,"dds":"0.021739130434782594","last_synced_commit":"576763858940f5ddb2bdc4fb780696be489b1c15"},"previous_names":[],"tags_count":22,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/brentyi%2Fjax_dataclasses","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/brentyi%2Fjax_dataclasses/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/brentyi%2Fjax_dataclasses/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/brentyi%2Fjax_dat
aclasses/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/brentyi","download_url":"https://codeload.github.com/brentyi/jax_dataclasses/tar.gz/refs/heads/main","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":253169135,"owners_count":21864980,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["dataclasses","jax","python"],"created_at":"2024-10-03T11:58:41.908Z","updated_at":"2025-05-09T00:27:35.882Z","avatar_url":"https://github.com/brentyi.png","language":"Python","funding_links":[],"categories":[],"sub_categories":[],"readme":"## jax_dataclasses\n\n![build](https://github.com/brentyi/jax_dataclasses/workflows/build/badge.svg)\n![mypy](https://github.com/brentyi/jax_dataclasses/workflows/mypy/badge.svg?branch=main)\n![lint](https://github.com/brentyi/jax_dataclasses/workflows/lint/badge.svg)\n[![codecov](https://codecov.io/gh/brentyi/jax_dataclasses/branch/main/graph/badge.svg?token=fFSx7CeKlW)](https://codecov.io/gh/brentyi/jax_dataclasses)\n\n\u003c!-- vim-markdown-toc GFM --\u003e\n\n- [Overview](#overview)\n- [Installation](#installation)\n- [Core interface](#core-interface)\n- [Static fields](#static-fields)\n- [Mutations](#mutations)\n- [Alternatives](#alternatives)\n- [Misc](#misc)\n\n\u003c!-- vim-markdown-toc --\u003e\n\n### Overview\n\n`jax_dataclasses` provides a simple wrapper around `dataclasses.dataclass` for use in\nJAX, which enables automatic support for:\n\n- [Pytree](https://jax.readthedocs.io/en/latest/pytrees.html) registration. 
This\n  allows dataclasses to be used at API boundaries in JAX.\n- Serialization via `flax.serialization`.\n\nDistinguishing features include:\n\n- An annotation-based interface for marking static fields.\n- Improved ergonomics for \"model surgery\" in nested structures.\n\n### Installation\n\nIn Python \u003e=3.7:\n\n```bash\npip install jax_dataclasses\n```\n\nWe can then import:\n\n```python\nimport jax_dataclasses as jdc\n```\n\n### Core interface\n\n`jax_dataclasses` is meant to provide a drop-in replacement for\n`dataclasses.dataclass`: \u003ccode\u003ejdc.\u003cstrong\u003epytree_dataclass\u003c/strong\u003e\u003c/code\u003e has\nthe same interface as `dataclasses.dataclass`, but also registers the target\nclass as a pytree node.\n\nWe also provide several aliases: `jdc.field`, `jdc.asdict`, `jdc.astuple`,\n`jdc.is_dataclass`, and `jdc.replace` are identical to their counterparts in\nthe standard `dataclasses` library.\n\n### Static fields\n\nTo mark a field as static (in this context: constant at compile time), we can\nwrap its type with \u003ccode\u003ejdc.\u003cstrong\u003eStatic[]\u003c/strong\u003e\u003c/code\u003e:\n\n```python\nimport jax\nimport jax_dataclasses as jdc\n\n@jdc.pytree_dataclass\nclass A:\n    a: jax.Array\n    b: jdc.Static[bool]\n```\n\nIn a pytree node, static fields are treated as part of the treedef rather than\nas children of the node; all fields that are not explicitly marked static\nshould contain arrays or child nodes.\n\nBonus: if you like `jdc.Static[]`, we also introduce\n\u003ccode\u003ejdc.\u003cstrong\u003ejit()\u003c/strong\u003e\u003c/code\u003e, which lets `jdc.Static[]`\nannotations mark static arguments directly in function signatures, for example:\n\n```python\n@jdc.jit\ndef f(a: jax.Array, b: jdc.Static[bool]) -\u003e jax.Array:\n  ...\n```\n\n### Mutations\n\nAll dataclasses are automatically marked as frozen and thus immutable (even when\nno `frozen=` parameter is passed in). 
To make changes to nested structures\neasier, \u003ccode\u003ejdc.\u003cstrong\u003ecopy_and_mutate\u003c/strong\u003e\u003c/code\u003e (a) makes a copy of a\npytree and (b) returns a context in which any of that copy's contained\ndataclasses are temporarily mutable:\n\n```python\nimport jax\nfrom jax import numpy as jnp\nimport jax_dataclasses as jdc\n\n@jdc.pytree_dataclass\nclass Node:\n  child: jax.Array\n\nobj = Node(child=jnp.zeros(3))\n\nwith jdc.copy_and_mutate(obj) as obj_updated:\n  # Make mutations to the dataclass. This is primarily useful for nested\n  # dataclasses.\n  #\n  # Does input validation by default: if the treedef, leaf shapes, or dtypes\n  # of `obj` and `obj_updated` don't match, an AssertionError will be raised.\n  # This can be disabled with a `validate=False` argument.\n  obj_updated.child = jnp.ones(3)\n\n# `obj` is unchanged; only the copy was mutated.\nprint(obj)\nprint(obj_updated)\n```\n\n### Alternatives\n\nA few other solutions exist for automatically integrating dataclass-style\nobjects into pytree structures. Great ones include:\n[`chex.dataclass`](https://github.com/deepmind/chex),\n[`flax.struct`](https://github.com/google/flax), and\n[`tjax.dataclass`](https://github.com/NeilGirdhar/tjax). These all influenced\nthis library.\n\nThe main differentiators of `jax_dataclasses` are:\n\n- **Static analysis support.** `tjax` has a custom mypy plugin to enable type\n  checking, but this plugin isn't supported by other tools. `flax.struct`\n  implements the\n  [`dataclass_transform`](https://github.com/microsoft/pyright/blob/main/specs/dataclass_transforms.md)\n  spec proposed by pyright, which likewise isn't supported by other tools. Because\n  `@jdc.pytree_dataclass` has the same API as `@dataclasses.dataclass`, it can\n  include pytree registration behavior at runtime while being treated as the\n  standard decorator during static analysis. 
This means that all static\n  checkers, language servers, and autocomplete engines that support the standard\n  `dataclasses` library should work out of the box with `jax_dataclasses`.\n\n- **Nested dataclasses.** Making replacements/modifications in deeply nested\n  dataclasses can be really frustrating. The three alternatives all introduce a\n  `.replace(self, ...)` method to dataclasses that's a bit more convenient than\n  the traditional `dataclasses.replace(obj, ...)` API for shallow changes, but\n  still becomes cumbersome when dataclasses are nested, since every level along\n  the path to the changed field needs its own `.replace()` call.\n  `jdc.copy_and_mutate()` was introduced to address this.\n\n- **Static field support.** Parameters that should not be traced in JAX should\n  be marked as static. This is supported in `flax`, `tjax`, and\n  `jax_dataclasses`, but not `chex`.\n\n- **Serialization.** When working with `flax`, being able to serialize\n  dataclasses is really handy. This is supported in `flax.struct` (naturally)\n  and `jax_dataclasses`, but not `chex` or `tjax`.\n\nYou can also eschew the dataclass-style interface entirely;\n[see how brax registers pytrees](https://github.com/google/brax/blob/730e05d4af58eada5b49a44e849107d76e386b9a/brax/pytree.py).\nThis is a reasonable thing to prefer: it requires some floating strings and\ngives up things that I care about but you may not (like immutability and\n`__post_init__`), but in exchange offers more flexibility with custom\n`__init__` methods.\n\n### Misc\n\n`jax_dataclasses` was originally written for and factored out of\n[jaxfg](http://github.com/brentyi/jaxfg), where\n[Nick Heppert](https://github.com/SuperN1ck) provided valuable feedback.\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fbrentyi%2Fjax_dataclasses","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fbrentyi%2Fjax_dataclasses","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fbrentyi%2Fjax_dataclasses/lists"}