Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/asem000/pytreeclass
Visualize, create, and operate on pytrees in the most intuitive way possible.
data dataclasses deep-learning jax machine-learning pipelines pytorch pytree tensorflow
Last synced: 6 days ago
- Host: GitHub
- URL: https://github.com/asem000/pytreeclass
- Owner: ASEM000
- License: apache-2.0
- Created: 2022-07-11T10:51:14.000Z (over 2 years ago)
- Default Branch: main
- Last Pushed: 2024-10-11T11:02:07.000Z (about 1 month ago)
- Last Synced: 2024-10-31T19:37:11.733Z (13 days ago)
- Topics: data, dataclasses, deep-learning, jax, machine-learning, pipelines, pytorch, pytree, tensorflow
- Language: Python
- Homepage: https://pytreeclass.rtfd.io/en/latest
- Size: 3.2 MB
- Stars: 41
- Watchers: 1
- Forks: 2
- Open Issues: 4
Metadata Files:
- Readme: README.md
- Changelog: CHANGELOG.md
- License: LICENSE
README
[**Installation**](#installation)
| [**Description**](#description)
| [**Quick Example**](#quick_example)
| [**Stateful Computation**](#stateful_computation)
| [**Benchmarks**](#more)
| [**Acknowledgements**](#acknowledgements)

![Tests](https://github.com/ASEM000/pytreeclass/actions/workflows/test_default.yml/badge.svg)
![Tests](https://github.com/ASEM000/pytreeclass/actions/workflows/test_jax.yml/badge.svg)
![Tests](https://github.com/ASEM000/pytreeclass/actions/workflows/test_numpy.yml/badge.svg)
![Tests](https://github.com/ASEM000/pytreeclass/actions/workflows/test_torch.yml/badge.svg)
![pyver](https://img.shields.io/badge/python-3.8%203.9%203.10%203.11_-blue)
![codestyle](https://img.shields.io/badge/codestyle-black-black)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ASEM000/pytreeclass/blob/main/assets/intro.ipynb)
[![Downloads](https://static.pepy.tech/badge/pytreeclass)](https://pepy.tech/project/pytreeclass)
[![codecov](https://codecov.io/gh/ASEM000/pytreeclass/branch/main/graph/badge.svg?token=TZBRMO0UQH)](https://codecov.io/gh/ASEM000/pytreeclass)
[![Documentation Status](https://readthedocs.org/projects/pytreeclass/badge/?version=latest)](https://pytreeclass.readthedocs.io/en/latest/?badge=latest)
![GitHub commit activity](https://img.shields.io/github/commit-activity/m/ASEM000/pytreeclass)
[![DOI](https://zenodo.org/badge/512717921.svg)](https://zenodo.org/badge/latestdoi/512717921)
![PyPI](https://img.shields.io/pypi/v/pytreeclass)
[![CodeFactor](https://www.codefactor.io/repository/github/asem000/pytreeclass/badge)](https://www.codefactor.io/repository/github/asem000/pytreeclass)

## Installation<a id="installation"></a>

```bash
pip install pytreeclass
```

**Install development version**

```bash
pip install git+https://github.com/ASEM000/pytreeclass
```

## Description<a id="description"></a>

`pytreeclass` is a JAX-compatible class builder for creating and operating on stateful JAX PyTrees in a performant and intuitive way, building on familiar concepts from `numpy`, `dataclasses`, and other libraries.
See the [documentation](https://pytreeclass.readthedocs.io/en/latest/notebooks/getting_started.html) and [Common recipes](https://pytreeclass.readthedocs.io/en/latest/notebooks/common_recipes.html) to check whether this library is a good fit for your work. _If you find the package useful, consider giving it a star._
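One numpy idea carried over to pytrees is boolean-mask assignment (`arr[arr > 5] = 100`). The following is a minimal pure-Python sketch of that idea, not `pytreeclass` code: a mask with the same nested structure as the data selects which leaves to replace, and a new structure is returned instead of mutating the input. The helpers `tree_map` and `masked_set` are hypothetical, and only tuples, lists, and dicts are treated as containers; in JAX an array is a single leaf and its mask is a boolean array.

```python
from typing import Any, Callable


def tree_map(fn: Callable[..., Any], tree: Any, *rest: Any) -> Any:
    """Hypothetical helper: apply fn leaf-wise over matching nested tuples/lists/dicts."""
    if isinstance(tree, (tuple, list)):
        return type(tree)(tree_map(fn, *children) for children in zip(tree, *rest))
    if isinstance(tree, dict):
        return {key: tree_map(fn, tree[key], *(r[key] for r in rest)) for key in tree}
    return fn(tree, *rest)  # a leaf: anything that is not a container


def masked_set(tree: Any, mask: Any, value: Any) -> Any:
    """Hypothetical helper: return a new tree with masked (True) leaves replaced by value."""
    return tree_map(lambda leaf, keep: value if keep else leaf, tree, mask)


data = {"a": 1.0, "b": (2.0, 3.0), "c": [4.0, 5.0, 6.0]}
mask = tree_map(lambda x: x > 5, data)  # same structure as data, with bool leaves
print(masked_set(data, mask, 100.0))
# {'a': 1.0, 'b': (2.0, 3.0), 'c': [4.0, 5.0, 100.0]}
```

Returning a new structure rather than editing in place mirrors the functional style that JAX transformations expect.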
## Quick Example<a id="quick_example"></a>

```python
import jax
import jax.numpy as jnp
import pytreeclass as tc


@tc.autoinit
class Tree(tc.TreeClass):
    a: float = 1.0
    b: tuple[float, float] = (2.0, 3.0)
    c: jax.Array = jnp.array([4.0, 5.0, 6.0])

    def __call__(self, x):
        return self.a + self.b[0] + self.c + x


tree = Tree()
# boolean mask with the same structure as `tree` (True where a leaf value is > 5)
mask = jax.tree_util.tree_map(lambda x: x > 5, tree)

# chained functional updates: each `.at[...].set(...)` returns a new instance
tree = (
    tree
    .at["a"].set(100.0)
    .at["b"][0].set(10.0)
    .at[mask].set(100.0)
)

print(tree)
# Tree(a=100.0, b=(10.0, 3.0), c=[  4.   5. 100.])

print(tc.tree_diagram(tree))
# Tree
# ├── .a=100.0
# ├── .b:tuple
# │   ├── [0]=10.0
# │   └── [1]=3.0
# └── .c=f32[3](μ=36.33, σ=45.02, ∈[4.00,100.00])

print(tc.tree_summary(tree))
# ┌─────┬──────┬─────┬──────┐
# │Name │Type  │Count│Size  │
# ├─────┼──────┼─────┼──────┤
# │.a   │float │1    │      │
# ├─────┼──────┼─────┼──────┤
# │.b[0]│float │1    │      │
# ├─────┼──────┼─────┼──────┤
# │.b[1]│float │1    │      │
# ├─────┼──────┼─────┼──────┤
# │.c   │f32[3]│3    │12.00B│
# ├─────┼──────┼─────┼──────┤
# │Σ    │Tree  │6    │12.00B│
# └─────┴──────┴─────┴──────┘

# ** pass it to jax transformations **
# works with jit, grad, vmap, etc.


@jax.jit
@jax.grad
def sum_tree(tree: Tree, x):
    return sum(tree(x))


print(sum_tree(tree, 1.0))
# Tree(a=3.0, b=(3.0, 0.0), c=[1. 1. 1.])
```

## Stateful Computation<a id="stateful_computation"></a>

Under `jax.jit`, [JAX requires state to be explicit](https://jax.readthedocs.io/en/latest/jax-101/07-state.html?highlight=state): for any class instance, the variables need to be separated from the class and passed explicitly. With `TreeClass`, there is no need to separate the instance variables; instead, the whole instance is passed as the state.
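The idea of passing the whole instance as explicit state can be illustrated without JAX. In this pure-Python sketch (an illustration, not how `pytreeclass` works internally), the state is a frozen dataclass, and a pure function takes the old state and returns a new one. `CounterState` and `step` are hypothetical names, and `dataclasses.replace` stands in for a functional update.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class CounterState:
    """Hypothetical immutable state: fields cannot be reassigned after construction."""

    calls: int = 0


def step(state: CounterState) -> CounterState:
    """Pure update: reads the old state and returns a new one; the input is untouched."""
    return dataclasses.replace(state, calls=state.calls + 1)


state = CounterState()
for _ in range(10):
    state = step(state)  # the caller threads the state through explicitly

print(state.calls)  # 10

try:
    state.calls = 0  # in-place mutation is rejected on a frozen dataclass
except dataclasses.FrozenInstanceError:
    print("frozen")
```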
Using the following pattern, state can be updated **functionally** under `jax.jit`:
```python
import jax
import pytreeclass as tc


class Counter(tc.TreeClass):
    def __init__(self, calls: int = 0):
        self.calls = calls

    def increment(self):
        self.calls += 1


counter = Counter()  # Counter(calls=0)
```

Next, we define the update function. Since the `increment` method mutates the internal state, we need the functional approach and update the state through `.at`: calling `.at[method_name].__call__(*args, **kwargs)` returns the value of the call together with a _new_ instance holding the updated state.
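Conceptually, the functional call can be pictured with the pure-Python sketch below: the mutating method runs on a copy, and the caller gets back the method's return value plus the updated copy, while the original instance is left unchanged. The `functional_call` helper and `PlainCounter` class are hypothetical, and the use of `copy.deepcopy` is an assumption made for illustration; this is not the library's implementation.

```python
import copy
from typing import Any, Tuple


class PlainCounter:
    """Hypothetical plain-Python counter with a mutating method."""

    def __init__(self, calls: int = 0):
        self.calls = calls

    def increment(self) -> int:
        self.calls += 1  # mutates the instance in place
        return self.calls


def functional_call(obj: Any, method_name: str, *args: Any, **kwargs: Any) -> Tuple[Any, Any]:
    """Hypothetical helper: run a mutating method on a copy and return (value, new_obj)."""
    new_obj = copy.deepcopy(obj)  # isolate the mutation from the caller's instance
    value = getattr(new_obj, method_name)(*args, **kwargs)
    return value, new_obj


counter = PlainCounter()
value, new_counter = functional_call(counter, "increment")
print(value, new_counter.calls, counter.calls)  # 1 1 0
```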
```python
@jax.jit
def update(counter):
    # run `increment` functionally: returns (return value, updated copy)
    value, new_counter = counter.at["increment"]()
    return new_counter


for i in range(10):
    counter = update(counter)

print(counter.calls)  # 10
```

## Benchmarks<a id="more"></a>

Benchmark flatten/unflatten compared to Flax and Equinox
[Figure: flatten/unflatten benchmark; panels: CPU, GPU]
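Benchmark simple training against `flax` and `equinox`: a simple sequential linear model, reported as the ratio of each library's training time to the `pytreeclass` (tc) time, so values above 1 mean `pytreeclass` was faster.

| Num of layers | Flax/tc time | Equinox/tc time |
| --- | --- | --- |
| 10 | 1.427 | 6.671 |
| 100 | 1.1130 | 2.714 |

## Acknowledgements<a id="acknowledgements"></a>

- [Lenses](https://hackage.haskell.org/package/lens)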
- [Treex](https://github.com/cgarciae/treex), [Equinox](https://github.com/patrick-kidger/equinox), [tree-math](https://github.com/google/tree-math), [Flax PyTreeNode](https://github.com/google/flax/commit/291a5f65549cf4522f0de033451cd83c0d0168d9), [TensorFlow](https://www.tensorflow.org), [PyTorch](https://pytorch.org)
- [Lovely JAX](https://github.com/xl0/lovely-jax)