<!-- codecov badge -->
[![codecov](https://codecov.io/gh/cgarciae/simple-pytree/branch/main/graph/badge.svg?token=3IKEUAU3C8)](https://codecov.io/gh/cgarciae/simple-pytree)

# Simple Pytree

A _dead simple_ Python package for creating custom JAX pytree objects.

* Strives to be minimal: the implementation is just ~100 lines of code
* Has no dependencies other than JAX
* Is compatible with both `dataclasses` and regular classes
* Has no intention of supporting neural network use cases (e.g. partitioning)

## Installation

```bash
pip install simple-pytree
```

## Usage

```python
import jax
from simple_pytree import Pytree

class Foo(Pytree):
    def __init__(self, x, y):
        self.x = x
        self.y = y

foo = Foo(1, 2)
foo = jax.tree_util.tree_map(lambda x: -x, foo)

assert foo.x == -1 and foo.y == -2
```

### Static fields
You can mark fields as static by assigning `static_field()` to a class attribute with the same name
as the instance attribute:

```python
import jax
from simple_pytree import Pytree, static_field

class Foo(Pytree):
    y = static_field()

    def __init__(self, x, y):
        self.x = x
        self.y = y

foo = Foo(1, 2)
foo = jax.tree_util.tree_map(lambda x: -x, foo)  # y is not modified

assert foo.x == -1 and foo.y == 2
```

Static fields are not included in the pytree leaves; they
are passed as pytree metadata instead.

### Dataclasses
`simple_pytree` provides a `dataclass` decorator you can use with classes
that contain `static_field`s:

```python
import jax
from simple_pytree import Pytree, dataclass, static_field

@dataclass
class Foo(Pytree):
    x: int
    y: int = static_field(default=2)

foo = Foo(1)
foo = jax.tree_util.tree_map(lambda x: -x, foo)  # y is not modified

assert foo.x == -1 and foo.y == 2
```
`simple_pytree.dataclass` is just a wrapper around `dataclasses.dataclass`, but
when you use it, static analysis tools and IDEs understand that `static_field` is a
field specifier, just like `dataclasses.field`.

### Mutability
`Pytree` objects are immutable by default after `__init__`:

```python
from simple_pytree import Pytree, static_field

class Foo(Pytree):
    y = static_field()

    def __init__(self, x, y):
        self.x = x
        self.y = y

foo = Foo(1, 2)
foo.x = 3  # raises AttributeError
```
If you want to make them mutable, pass `mutable=True` in the class definition:

```python
from simple_pytree import Pytree, static_field

class Foo(Pytree, mutable=True):
    y = static_field()

    def __init__(self, x, y):
        self.x = x
        self.y = y

foo = Foo(1, 2)
foo.x = 3  # OK
```

### Replacing fields

If you want to make a copy of a `Pytree` object with some fields modified, you can use the `.replace()` method:

```python
from simple_pytree import Pytree, static_field

class Foo(Pytree):
    y = static_field()

    def __init__(self, x, y):
        self.x = x
        self.y = y

foo = Foo(1, 2)
foo = foo.replace(x=10)

assert foo.x == 10 and foo.y == 2
```

`replace` works for both mutable and immutable `Pytree` objects. If the class
is a `dataclass`, `replace` internally uses `dataclasses.replace`.