{"id":17143836,"url":"https://github.com/asem000/pytreeclass","last_synced_at":"2025-04-07T07:05:58.108Z","repository":{"id":44464600,"uuid":"512717921","full_name":"ASEM000/pytreeclass","owner":"ASEM000","description":"Visualize, create, and operate on pytrees in the most intuitive way possible.","archived":false,"fork":false,"pushed_at":"2025-01-11T19:12:38.000Z","size":3366,"stargazers_count":45,"open_issues_count":4,"forks_count":2,"subscribers_count":1,"default_branch":"main","last_synced_at":"2025-03-31T06:04:37.235Z","etag":null,"topics":["data","dataclasses","deep-learning","jax","machine-learning","pipelines","pytorch","pytree","tensorflow"],"latest_commit_sha":null,"homepage":"https://pytreeclass.rtfd.io/en/latest","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"apache-2.0","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/ASEM000.png","metadata":{"files":{"readme":"README.md","changelog":"CHANGELOG.md","contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null}},"created_at":"2022-07-11T10:51:14.000Z","updated_at":"2025-03-28T12:54:09.000Z","dependencies_parsed_at":"2024-10-14T20:42:57.993Z","dependency_job_id":"d8262a08-3fc5-4030-9faf-4508d238b4ec","html_url":"https://github.com/ASEM000/pytreeclass","commit_stats":{"total_commits":1047,"total_committers":3,"mean_commits":349.0,"dds":0.05635148042024829,"last_synced_commit":"8c117f302744cea00297df4952a552026007f536"},"previous_names":[],"tags_count":63,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/ASEM000%2Fpytreeclass","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/ASEM000%2Fpytreeclass/tags","rel
eases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/ASEM000%2Fpytreeclass/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/ASEM000%2Fpytreeclass/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/ASEM000","download_url":"https://codeload.github.com/ASEM000/pytreeclass/tar.gz/refs/heads/main","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":247608150,"owners_count":20965952,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["data","dataclasses","deep-learning","jax","machine-learning","pipelines","pytorch","pytree","tensorflow"],"created_at":"2024-10-14T20:42:25.556Z","updated_at":"2025-04-07T07:05:58.066Z","avatar_url":"https://github.com/ASEM000.png","language":"Python","funding_links":[],"categories":[],"sub_categories":[],"readme":"\u003c!-- \u003ch1 align=\"center\" style=\"font-family:Monospace\" \u003ePy🌲Class\u003c/h1\u003e --\u003e\n\u003ch5 align=\"center\"\u003e\n\u003cimg width=\"250px\" src=\"https://github.com/ASEM000/pytreeclass/assets/48389287/95e879f2-69d9-420b-bb64-012fa0b4eeb8\"\u003e \u003cbr\u003e\n\n\u003cbr\u003e\n\n[**Installation**](#installation)\n|[**Description**](#description)\n|[**Quick 
Example**](#quick_example)\n|[**StatefulComputation**](#stateful_computation)\n|[**Benchmarks**](#more)\n|[**Acknowledgements**](#acknowledgements)\n\n![Tests](https://github.com/ASEM000/pytreeclass/actions/workflows/test_default.yml/badge.svg)\n![Tests](https://github.com/ASEM000/pytreeclass/actions/workflows/test_jax.yml/badge.svg)\n![Tests](https://github.com/ASEM000/pytreeclass/actions/workflows/test_numpy.yml/badge.svg)\n![Tests](https://github.com/ASEM000/pytreeclass/actions/workflows/test_torch.yml/badge.svg)\n![pyver](https://img.shields.io/badge/python-3.8%203.9%203.10%203.11_-blue)\n![codestyle](https://img.shields.io/badge/codestyle-black-black)\n[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ASEM000/pytreeclass/blob/main/assets/intro.ipynb)\n[![Downloads](https://static.pepy.tech/badge/pytreeclass)](https://pepy.tech/project/pytreeclass)\n[![codecov](https://codecov.io/gh/ASEM000/pytreeclass/branch/main/graph/badge.svg?token=TZBRMO0UQH)](https://codecov.io/gh/ASEM000/pytreeclass)\n[![Documentation Status](https://readthedocs.org/projects/pytreeclass/badge/?version=latest)](https://pytreeclass.readthedocs.io/en/latest/?badge=latest)\n![GitHub commit activity](https://img.shields.io/github/commit-activity/m/ASEM000/pytreeclass)\n[![DOI](https://zenodo.org/badge/512717921.svg)](https://zenodo.org/badge/latestdoi/512717921)\n![PyPI](https://img.shields.io/pypi/v/pytreeclass)\n[![CodeFactor](https://www.codefactor.io/repository/github/asem000/pytreeclass/badge)](https://www.codefactor.io/repository/github/asem000/pytreeclass)\n\n\u003c/h5\u003e\n\n## 🛠️ Installation\u003ca id=\"installation\"\u003e\u003c/a\u003e\n\n```python\npip install pytreeclass\n```\n\n**Install development version**\n\n```python\npip install git+https://github.com/ASEM000/pytreeclass\n```\n\n## 📖 Description\u003ca id=\"description\"\u003e\u003c/a\u003e\n\n`pytreeclass` is a JAX-compatible class builder to create 
and operate on stateful JAX PyTrees in a performant and intuitive way, by building on familiar concepts found in `numpy`, `dataclasses`, and others.\n\nSee [documentation](https://pytreeclass.readthedocs.io/en/latest/notebooks/getting_started.html) and [🍳 Common recipes](https://pytreeclass.readthedocs.io/en/latest/notebooks/common_recipes.html) to check if this library is a good fit for your work. _If you find the package useful, consider giving it a 🌟._\n\n## ⏩ Quick Example \u003ca id=\"quick_example\"\u003e\u003c/a\u003e\n\n\u003cdiv align=\"center\"\u003e\n\u003ctable\u003e\n\u003ctr\u003e\u003ctd align=\"center\"\u003e\u003c/td\u003e\u003c/tr\u003e\n\u003ctr\u003e\n\u003ctd\u003e\n\n```python\nimport jax\nimport jax.numpy as jnp\nimport pytreeclass as tc\n\n@tc.autoinit\nclass Tree(tc.TreeClass):\n    a: float = 1.0\n    b: tuple[float, float] = (2.0, 3.0)\n    c: jax.Array = jnp.array([4.0, 5.0, 6.0])\n\n    def __call__(self, x):\n        return self.a + self.b[0] + self.c + x\n\n\ntree = Tree()\nmask = jax.tree_util.tree_map(lambda x: x \u003e 5, tree)\ntree = tree\\\n       .at[\"a\"].set(100.0)\\\n       .at[\"b\"][0].set(10.0)\\\n       .at[mask].set(100.0)\n\nprint(tree)\n# Tree(a=100.0, b=(10.0, 3.0), c=[  4.   5. 
100.])\n\nprint(tc.tree_diagram(tree))\n# Tree\n# ├── .a=100.0\n# ├── .b:tuple\n# │   ├── [0]=10.0\n# │   └── [1]=3.0\n# └── .c=f32[3](μ=36.33, σ=45.02, ∈[4.00,100.00])\n\nprint(tc.tree_summary(tree))\n# ┌─────┬──────┬─────┬──────┐\n# │Name │Type  │Count│Size  │\n# ├─────┼──────┼─────┼──────┤\n# │.a   │float │1    │      │\n# ├─────┼──────┼─────┼──────┤\n# │.b[0]│float │1    │      │\n# ├─────┼──────┼─────┼──────┤\n# │.b[1]│float │1    │      │\n# ├─────┼──────┼─────┼──────┤\n# │.c   │f32[3]│3    │12.00B│\n# ├─────┼──────┼─────┼──────┤\n# │Σ    │Tree  │6    │12.00B│\n# └─────┴──────┴─────┴──────┘\n\n# ** pass it to jax transformations **\n# works with jit, grad, vmap, etc.\n\n@jax.jit\n@jax.grad\ndef sum_tree(tree: Tree, x):\n    return sum(tree(x))\n\nprint(sum_tree(tree, 1.0))\n# Tree(a=3.0, b=(3.0, 0.0), c=[1. 1. 1.])\n```\n\n\u003c/td\u003e\n\n\u003c/tr\u003e\n\u003c/table\u003e\n\u003c/div\u003e\n\n## 📜 Stateful computations\u003ca id=\"stateful_computation\"\u003e\u003c/a\u003e\n\n[Under jax.jit jax requires states to be explicit](https://jax.readthedocs.io/en/latest/jax-101/07-state.html?highlight=state). This means that, for any class instance, variables need to be separated from the class and passed explicitly. However, when using `TreeClass` there is no need to separate the instance variables; instead, the whole instance is passed as state.\n\nUsing the following pattern, updating state **functionally** can be achieved under `jax.jit`:\n\n\u003cdiv align=\"center\"\u003e\n\u003ctable\u003e\n\u003ctr\u003e\u003ctd align=\"center\"\u003e\u003c/td\u003e\u003c/tr\u003e\n\u003ctr\u003e\n\u003ctd\u003e\n\n```python\nimport jax\nimport pytreeclass as tc\n\nclass Counter(tc.TreeClass):\n    def __init__(self, calls: int = 0):\n        self.calls = calls\n\n    def increment(self):\n        self.calls += 1\n\ncounter = Counter()  # Counter(calls=0)\n```\n\n\u003c/td\u003e\n\n\u003c/tr\u003e\n\u003c/table\u003e\n\u003c/div\u003e\n\nHere, we define the update function. 
Since the `increment` method mutates the internal state, we need to update the state functionally by using `.at`. To achieve this, we can use `.at[method_name].__call__(*args,**kwargs)`; this functional call returns the value of the call and a _new_ model instance with the updated state.\n\n\u003cdiv align=\"center\"\u003e\n\u003ctable\u003e\n\u003ctr\u003e\u003ctd align=\"center\"\u003e\u003c/td\u003e\u003c/tr\u003e\n\u003ctr\u003e\n\u003ctd\u003e\n\n```python\n@jax.jit\ndef update(counter):\n    value, new_counter = counter.at[\"increment\"]()\n    return new_counter\n\nfor i in range(10):\n    counter = update(counter)\n\nprint(counter.calls) # 10\n```\n\n\u003c/td\u003e\n\n\u003c/tr\u003e\n\u003c/table\u003e\n\u003c/div\u003e\n\n\u003c/details\u003e\n\n## ➕ Benchmarks\u003ca id=\"more\"\u003e\u003c/a\u003e\n\n\u003cdetails\u003e\n\u003csummary\u003eBenchmark flatten/unflatten compared to Flax and Equinox \u003c/summary\u003e\n\n\u003ca href=\"https://colab.research.google.com/github/ASEM000/pytreeclass/blob/main/assets/benchmark_flatten_unflatten.ipynb\" target=\"_parent\"\u003e\u003cimg src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/\u003e\u003c/a\u003e\n\n\u003ctable\u003e\n\n\u003ctr\u003e\u003ctd align=\"center\"\u003eCPU\u003c/td\u003e\u003ctd align=\"center\"\u003eGPU\u003c/td\u003e\u003c/tr\u003e\n\n\u003ctr\u003e\n\n\u003ctd\u003e\u003cimg src='assets/benchmark_cpu.png'\u003e\u003c/td\u003e\n\n\u003c/tr\u003e\n\n\u003c/table\u003e\n\n\u003c/details\u003e\n\n\u003cdetails\u003e\n\n\u003csummary\u003eBenchmark simple training against `flax` and `equinox` \u003c/summary\u003e\n\nTraining simple sequential linear benchmark against `flax` and `equinox`\n\n\u003ctable\u003e\n\n\u003ctr\u003e\n\u003ctd align=\"center\"\u003eNum of layers\u003c/td\u003e\n\u003ctd align=\"center\"\u003eFlax/tc time\u003cbr\u003e\u003ca 
href=\"https://colab.research.google.com/github/ASEM000/pytreeclass/blob/main/assets/benchmark_nn_training_flax.ipynb\" target=\"_parent\"\u003e\u003cimg src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/\u003e\u003c/a\u003e\u003c/td\u003e\n\u003ctd align=\"center\"\u003eEquinox/tc time\u003cbr\u003e \u003ca href=\"https://colab.research.google.com/github/ASEM000/pytreeclass/blob/main/assets/benchmark_nn_training_equinox.ipynb\" target=\"_parent\"\u003e\u003cimg src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/\u003e\u003c/a\u003e\u003c/td\u003e\n\u003c/tr\u003e\n\n\u003ctr\u003e\n\u003ctd align=\"center\"\u003e10\u003c/td\u003e\n\u003ctd align=\"center\"\u003e1.427\u003c/td\u003e\n\u003ctd align=\"center\"\u003e6.671\u003c/td\u003e\n\u003c/tr\u003e\n\n\u003ctr\u003e\n\u003ctd align=\"center\"\u003e100\u003c/td\u003e\n\u003ctd align=\"center\"\u003e1.1130\u003c/td\u003e\n\u003ctd align=\"center\"\u003e2.714\u003c/td\u003e\n\u003c/tr\u003e\n\n\u003c/table\u003e\n\n\u003c/details\u003e\n\n## 📙 Acknowledgements\u003ca id=\"acknowledgements\"\u003e\u003c/a\u003e\n\n- [Lenses](https://hackage.haskell.org/package/lens)\n- [Treex](https://github.com/cgarciae/treex), [Equinox](https://github.com/patrick-kidger/equinox), [tree-math](https://github.com/google/tree-math), [Flax PyTreeNode](https://github.com/google/flax/commit/291a5f65549cf4522f0de033451cd83c0d0168d9), [TensorFlow](https://www.tensorflow.org), [PyTorch](https://pytorch.org)\n- [Lovely JAX](https://github.com/xl0/lovely-jax)\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fasem000%2Fpytreeclass","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fasem000%2Fpytreeclass","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fasem000%2Fpytreeclass/lists"}