{"id":13738320,"url":"https://github.com/cgarciae/treeo","last_synced_at":"2025-05-08T16:33:19.636Z","repository":{"id":37988992,"uuid":"408184316","full_name":"cgarciae/treeo","owner":"cgarciae","description":"A small library for creating and manipulating custom JAX Pytree classes","archived":true,"fork":false,"pushed_at":"2023-02-26T16:58:14.000Z","size":1617,"stargazers_count":59,"open_issues_count":2,"forks_count":4,"subscribers_count":4,"default_branch":"master","last_synced_at":"2024-08-04T03:12:18.003Z","etag":null,"topics":["jax","pytree"],"latest_commit_sha":null,"homepage":"https://cgarciae.github.io/treeo","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"mit","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/cgarciae.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null}},"created_at":"2021-09-19T16:54:16.000Z","updated_at":"2024-01-06T18:44:16.000Z","dependencies_parsed_at":"2024-01-08T17:20:11.780Z","dependency_job_id":"cd6739dc-2c0a-4759-a118-bfc968484cee","html_url":"https://github.com/cgarciae/treeo","commit_stats":{"total_commits":142,"total_committers":5,"mean_commits":28.4,"dds":0.08450704225352113,"last_synced_commit":"b77bbfaf392a655c914b95a4d2dec6b9df9baf57"},"previous_names":[],"tags_count":14,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/cgarciae%2Ftreeo","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/cgarciae%2Ftreeo/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/cgarciae%2Ftreeo/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/cgarciae%2Ftre
eo/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/cgarciae","download_url":"https://codeload.github.com/cgarciae/treeo/tar.gz/refs/heads/master","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":224746560,"owners_count":17363073,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["jax","pytree"],"created_at":"2024-08-03T03:02:18.414Z","updated_at":"2024-11-15T07:30:46.311Z","avatar_url":"https://github.com/cgarciae.png","language":"Python","funding_links":[],"categories":["Python"],"sub_categories":[],"readme":"_Deprecation Notice_: This library was an experiment trying to get pytree Modules working with Flax-like collections. I'd currently recommend the following alternatives:\n* Just custom pytrees: [simple_pytree](https://github.com/cgarciae/simple-pytree)\n* Pytree module system: [equinox](https://github.com/patrick-kidger/equinox)\n* Production-ready module system: [flax](https://github.com/google/flax)\n\n# Treeo \n\n_A small library for creating and manipulating custom JAX Pytree classes_\n\n* **Light-weight**: has no dependencies other than `jax`.\n* **Compatible**: Treeo `Tree` objects are compatible with any `jax` function that accepts Pytrees.\n* **Standards-based**: `treeo.field` is built on top of Python's `dataclasses.field`.\n* **Flexible**: Treeo is compatible with both dataclass and non-dataclass classes.\n\nTreeo lets you easily create class-based Pytrees so your custom objects interact seamlessly with JAX. 
Uses of Treeo can range from creating simple JAX-aware utility classes to using it as the core abstraction for full-blown frameworks. Treeo was originally extracted from the core of [Treex](https://github.com/cgarciae/treex) and has a lot in common with [flax.struct](https://flax.readthedocs.io/en/latest/flax.struct.html#module-flax.struct).\n\n[Documentation](https://cgarciae.github.io/treeo) | [User Guide](https://cgarciae.github.io/treeo/user-guide/intro)\n\n## Installation\nInstall using pip:\n```bash\npip install treeo\n```\n\n## Basics\nWith Treeo you can define your own custom Pytree classes by inheriting from Treeo's `Tree` class and using the `field` function to declare which fields are nodes (children) and which are static (metadata):\n\n```python\nfrom dataclasses import dataclass\n\nimport jax\nimport jax.numpy as jnp\nimport treeo as to\n\n@dataclass\nclass Person(to.Tree):\n    height: jnp.ndarray = to.field(node=True) # I am a node field!\n    name: str = to.field(node=False) # I am a static field!\n```\n`field` is a thin wrapper around `dataclasses.field`, so you can define your Pytrees as dataclasses, but Treeo fully supports non-dataclass classes as well. Since all `Tree` instances are Pytrees, they work with the functions in the `jax` library as expected:\n\n```python\np = Person(height=jnp.array(1.8), name=\"John\")\n\n# Trees can be jitted!\njax.jit(lambda person: person)(p) # Person(height=array(1.8), name='John')\n\n# Trees can be mapped!\njax.tree_util.tree_map(lambda x: 2 * x, p) # Person(height=array(3.6), name='John')\n```\n#### Kinds\nTreeo also includes a kind system that lets you give semantic meaning to fields (what a field represents within your application). 
A kind is just a type you pass to `field` via its `kind` argument:\n\n```python\nclass Parameter: pass\nclass BatchStat: pass\n\nclass BatchNorm(to.Tree):\n    scale: jnp.ndarray = to.field(node=True, kind=Parameter)\n    mean: jnp.ndarray = to.field(node=True, kind=BatchStat)\n```\n\nKinds are particularly useful as a filtering mechanism via [treeo.filter](https://cgarciae.github.io/treeo/user-guide/api/filter):\n\n```python\nmodel = BatchNorm(...)\n\n# select only Parameters, mean is filtered out\nparams = to.filter(model, Parameter) # BatchNorm(scale=array(...), mean=Nothing)\n```\n`Nothing` behaves like `None` in Python, but it is a special value that Treeo uses to represent the absence of a value.\n\nTreeo also offers the [merge](https://cgarciae.github.io/treeo/user-guide/api/merge) function, which lets you rejoin filtered Trees using logic similar to Python's `dict.update`, but applied recursively:\n```python hl_lines=\"3\"\ndef loss_fn(params, model, ...):\n    # add traced params to model\n    model = to.merge(model, params)\n    ...\n\n# gradient only w.r.t. 
params\nparams = to.filter(model, Parameter) # BatchNorm(scale=array(...), mean=Nothing)\ngrads = jax.grad(loss_fn)(params, model, ...)\n```\n\nFor a more in-depth tour, check out the [User Guide](https://cgarciae.github.io/treeo/user-guide/intro).\n\n## Examples\n\n### A simple Tree\n```python\nfrom dataclasses import dataclass\n\nimport jax\nimport jax.numpy as jnp\nimport treeo as to\n\n@dataclass\nclass Character(to.Tree):\n    position: jnp.ndarray = to.field(node=True)    # node field\n    name: str = to.field(node=False, opaque=True)  # static field\n\ncharacter = Character(position=jnp.array([0.0, 0.0]), name='Adam')\n\n# character can freely pass through jit\n@jax.jit\ndef update(character: Character, velocity, dt) -\u003e Character:\n    character.position += velocity * dt\n    return character\n\ncharacter = update(character, velocity=jnp.array([1.0, 0.2]), dt=0.1)\n```\n### A Stateful Tree\n```python\nfrom dataclasses import dataclass\n\nimport jax\nimport jax.numpy as jnp\nimport treeo as to\n\n@dataclass\nclass Counter(to.Tree):\n    n: jnp.ndarray = to.field(default=jnp.array(0), node=True) # node\n    step: int = to.field(default=1, node=False) # static\n\n    def inc(self):\n        self.n += self.step\n\ncounter = Counter(step=2) # Counter(n=jnp.array(0), step=2)\n\n@jax.jit\ndef update(counter: Counter):\n    counter.inc()\n    return counter\n\ncounter = update(counter) # Counter(n=jnp.array(2), step=2)\n```\n\n### Full Example - Linear Regression\n\n```python\nimport jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\nimport numpy as np\n\nimport treeo as to\n\n\nclass Linear(to.Tree):\n    w: jnp.ndarray = to.node()\n    b: jnp.ndarray = to.node()\n\n    def __init__(self, din, dout, key):\n        self.w = jax.random.uniform(key, shape=(din, dout))\n        self.b = jnp.zeros(shape=(dout,))\n\n    def __call__(self, x):\n        return jnp.dot(x, self.w) + self.b\n\n\n@jax.value_and_grad\ndef loss_fn(model, x, y):\n    y_pred = model(x)\n    loss = jnp.mean((y_pred - y) ** 2)\n\n    return 
loss\n\n\ndef sgd(param, grad):\n    return param - 0.1 * grad\n\n\n@jax.jit\ndef train_step(model, x, y):\n    loss, grads = loss_fn(model, x, y)\n    # grads is a Linear with the same structure as model, so the update is applied leaf-wise\n    model = jax.tree_util.tree_map(sgd, model, grads)\n\n    return loss, model\n\n\nx = np.random.uniform(size=(500, 1))\ny = 1.4 * x - 0.3 + np.random.normal(scale=0.1, size=(500, 1))\n\nkey = jax.random.PRNGKey(0)\nmodel = Linear(1, 1, key=key)\n\nfor step in range(1000):\n    loss, model = train_step(model, x, y)\n    if step % 100 == 0:\n        print(f\"loss: {loss:.4f}\")\n\nX_test = np.linspace(x.min(), x.max(), 100)[:, None]\ny_pred = model(X_test)\n\nplt.scatter(x, y, c=\"k\", label=\"data\")\nplt.plot(X_test, y_pred, c=\"b\", linewidth=2, label=\"prediction\")\nplt.legend()\nplt.show()\n```\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fcgarciae%2Ftreeo","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fcgarciae%2Ftreeo","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fcgarciae%2Ftreeo/lists"}