https://github.com/asem000/pytreeclass

Visualize, create, and operate on pytrees in the most intuitive way possible.

[**Installation**](#installation)
|[**Description**](#description)
|[**Quick Example**](#quick_example)
|[**StatefulComputation**](#stateful_computation)
|[**Benchmarks**](#more)
|[**Acknowledgements**](#acknowledgements)

![Tests](https://github.com/ASEM000/pytreeclass/actions/workflows/test_default.yml/badge.svg)
![Tests](https://github.com/ASEM000/pytreeclass/actions/workflows/test_jax.yml/badge.svg)
![Tests](https://github.com/ASEM000/pytreeclass/actions/workflows/test_numpy.yml/badge.svg)
![Tests](https://github.com/ASEM000/pytreeclass/actions/workflows/test_torch.yml/badge.svg)
![pyver](https://img.shields.io/badge/python-3.8%203.9%203.10%203.11_-blue)
![codestyle](https://img.shields.io/badge/codestyle-black-black)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/ASEM000/pytreeclass/blob/main/assets/intro.ipynb)
[![Downloads](https://static.pepy.tech/badge/pytreeclass)](https://pepy.tech/project/pytreeclass)
[![codecov](https://codecov.io/gh/ASEM000/pytreeclass/branch/main/graph/badge.svg?token=TZBRMO0UQH)](https://codecov.io/gh/ASEM000/pytreeclass)
[![Documentation Status](https://readthedocs.org/projects/pytreeclass/badge/?version=latest)](https://pytreeclass.readthedocs.io/en/latest/?badge=latest)
![GitHub commit activity](https://img.shields.io/github/commit-activity/m/ASEM000/pytreeclass)
[![DOI](https://zenodo.org/badge/512717921.svg)](https://zenodo.org/badge/latestdoi/512717921)
![PyPI](https://img.shields.io/pypi/v/pytreeclass)
[![CodeFactor](https://www.codefactor.io/repository/github/asem000/pytreeclass/badge)](https://www.codefactor.io/repository/github/asem000/pytreeclass)

## πŸ› οΈ Installation

```bash
pip install pytreeclass
```

**Install development version**

```bash
pip install git+https://github.com/ASEM000/pytreeclass
```

## 📖 Description

`pytreeclass` is a JAX-compatible class builder to create and operate on stateful JAX PyTrees in a performant and intuitive way, by building on familiar concepts found in `numpy`, `dataclasses`, and others.

See the [documentation](https://pytreeclass.readthedocs.io/en/latest/notebooks/getting_started.html) and [🍳 Common recipes](https://pytreeclass.readthedocs.io/en/latest/notebooks/common_recipes.html) to check whether this library is a good fit for your work. _If you find the package useful, consider giving it a 🌟._
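
To see what a PyTree class builder saves you, here is a minimal plain-JAX sketch of registering a class as a PyTree by hand with `jax.tree_util.register_pytree_node_class`. The `Point` class and `norm_squared` function are hypothetical; the sketch shows the kind of boilerplate `TreeClass` is meant to remove, not how `pytreeclass` implements it internally.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


# Hand-written PyTree registration in plain JAX (hypothetical example class).
@register_pytree_node_class
class Point:
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def tree_flatten(self):
        # children: the leaves JAX traces; aux_data: static metadata (none here)
        return (self.x, self.y), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # rebuild an instance from the (possibly traced) leaves
        return cls(*children)


@jax.jit
def norm_squared(p: Point):
    return p.x**2 + p.y**2


p = Point(jnp.array(3.0), jnp.array(4.0))
leaves, treedef = jax.tree_util.tree_flatten(p)
restored = jax.tree_util.tree_unflatten(treedef, leaves)
print(norm_squared(p))  # 25.0
```

Every new field means editing both `tree_flatten` and `tree_unflatten`; a class builder derives them from the class definition instead.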

## ⏩ Quick Example

```python
import jax
import jax.numpy as jnp
import pytreeclass as tc

@tc.autoinit  # generate __init__ from the annotated fields
class Tree(tc.TreeClass):
    a: float = 1.0
    b: tuple[float, float] = (2.0, 3.0)
    c: jax.Array = jnp.array([4.0, 5.0, 6.0])

    def __call__(self, x):
        return self.a + self.b[0] + self.c + x

tree = Tree()
# boolean mask with the same structure as `tree`
mask = jax.tree_util.tree_map(lambda x: x > 5, tree)
# each `.at[...]` call returns a new instance; the original is not mutated
tree = (
    tree
    .at["a"].set(100.0)
    .at["b"][0].set(10.0)
    .at[mask].set(100.0)
)

print(tree)
# Tree(a=100.0, b=(10.0, 3.0), c=[  4.   5. 100.])

print(tc.tree_diagram(tree))
# Tree
# ├── .a=100.0
# ├── .b:tuple
# │   ├── [0]=10.0
# │   └── [1]=3.0
# └── .c=f32[3](μ=36.33, σ=45.02, ∈[4.00,100.00])

print(tc.tree_summary(tree))
# ┌─────┬──────┬─────┬──────┐
# │Name │Type  │Count│Size  │
# ├─────┼──────┼─────┼──────┤
# │.a   │float │1    │      │
# ├─────┼──────┼─────┼──────┤
# │.b[0]│float │1    │      │
# ├─────┼──────┼─────┼──────┤
# │.b[1]│float │1    │      │
# ├─────┼──────┼─────┼──────┤
# │.c   │f32[3]│3    │12.00B│
# ├─────┼──────┼─────┼──────┤
# │Σ    │Tree  │6    │12.00B│
# └─────┴──────┴─────┴──────┘

# ** pass it to jax transformations **
# works with jit, grad, vmap, etc.

@jax.jit
@jax.grad
def sum_tree(tree: Tree, x):
    return jnp.sum(tree(x))

print(sum_tree(tree, 1.0))
# Tree(a=3.0, b=(3.0, 0.0), c=[1. 1. 1.])
```
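
The masked update `.at[mask].set(100.0)` replaces exactly the leaves where the mask is `True`. As a rough plain-JAX analogue (an assumption about the semantics, not `pytreeclass`'s implementation), the same result for the `c` field can be reproduced on a dict PyTree with `jax.tree_util.tree_map` and `jnp.where`:

```python
import jax
import jax.numpy as jnp

# A dict standing in for the `Tree` instance (illustrative only).
tree = {"a": 1.0, "b": (2.0, 3.0), "c": jnp.array([4.0, 5.0, 6.0])}

# Boolean mask with the same structure as `tree`.
mask = jax.tree_util.tree_map(lambda x: x > 5, tree)

# Take 100.0 where the mask is True, keep the original leaf elsewhere.
updated = jax.tree_util.tree_map(lambda x, m: jnp.where(m, 100.0, x), tree, mask)

print(updated["c"])  # [  4.   5. 100.]
```

Unlike `pytreeclass`, this version turns the Python float leaves into JAX arrays, because `jnp.where` always returns an array.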

## 📜 Stateful computations

[Under `jax.jit`, JAX requires state to be explicit](https://jax.readthedocs.io/en/latest/jax-101/07-state.html?highlight=state), which means that for any class instance, the variables need to be separated from the class and passed explicitly. With `TreeClass` there is no need to separate the instance variables; instead, the whole instance is passed as the state.
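
For comparison, here is a minimal plain-JAX sketch of the explicit-state style described above, without `pytreeclass`: the counter's state is a dict passed into a jitted function, and the function returns a new state instead of mutating anything. The names `increment` and `state` are illustrative only.

```python
import jax
import jax.numpy as jnp


@jax.jit
def increment(state: dict) -> dict:
    # No mutation: build and return a new state dict.
    return {"calls": state["calls"] + 1}


state = {"calls": jnp.asarray(0)}
for _ in range(10):
    state = increment(state)  # the caller threads the state through explicitly

print(int(state["calls"]))  # 10
```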

Using the following pattern, state can be updated **functionally** under `jax.jit`:

```python
import jax
import pytreeclass as tc

class Counter(tc.TreeClass):
    def __init__(self, calls: int = 0):
        self.calls = calls

    def increment(self):
        # in-place update of `self.calls`
        self.calls += 1

counter = Counter()  # Counter(calls=0)
```

Next, we define the update function. Because the `increment` method mutates the internal state, the state must be updated functionally using `.at`. To do this, use `.at[method_name].__call__(*args, **kwargs)` (or simply `.at[method_name](*args, **kwargs)`); this functional call returns the method's return value together with a _new_ instance holding the updated state.

```python
@jax.jit
def update(counter):
    # call `increment` on a copy; returns (method return value, updated copy)
    value, new_counter = counter.at["increment"]()
    return new_counter

for i in range(10):
    counter = update(counter)

print(counter.calls)  # 10
```
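
Conceptually, `counter.at["increment"]()` behaves like calling the method on a copy of the instance. The pure-Python sketch below illustrates that idea with a hypothetical `call_functionally` helper and a plain `PlainCounter` class; neither is part of `pytreeclass`, and the library's actual implementation (which also has to work under JAX tracing) may differ.

```python
import copy


def call_functionally(obj, method_name, *args, **kwargs):
    """Hypothetical helper: run a method on a copy, return (result, copy)."""
    # Shallow copy: rebinding attributes on the copy leaves `obj` untouched,
    # but nested mutable containers would still be shared.
    new_obj = copy.copy(obj)
    value = getattr(new_obj, method_name)(*args, **kwargs)
    return value, new_obj


class PlainCounter:
    def __init__(self, calls: int = 0):
        self.calls = calls

    def increment(self):
        self.calls += 1


counter = PlainCounter()
value, new_counter = call_functionally(counter, "increment")
print(counter.calls, new_counter.calls, value)  # 0 1 None
```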

## ➕ Benchmarks

Benchmark of flatten/unflatten compared to Flax and Equinox

_[Figure: flatten/unflatten benchmark, CPU and GPU panels]_

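Benchmark of simple training against `flax` and `equinox`, using a simple sequential linear model:

| Number of layers | Flax/tc time | Equinox/tc time |
| --- | --- | --- |
| 10 | 1.427 | 6.671 |
| 100 | 1.1130 | 2.714 |

Each ratio is the other library's training time divided by the `pytreeclass` time, so values above 1 mean `pytreeclass` was faster.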

## 📙 Acknowledgements

- [Lenses](https://hackage.haskell.org/package/lens)
- [Treex](https://github.com/cgarciae/treex), [Equinox](https://github.com/patrick-kidger/equinox), [tree-math](https://github.com/google/tree-math), [Flax PyTreeNode](https://github.com/google/flax/commit/291a5f65549cf4522f0de033451cd83c0d0168d9), [TensorFlow](https://www.tensorflow.org), [PyTorch](https://pytorch.org)
- [Lovely JAX](https://github.com/xl0/lovely-jax)