https://github.com/cgarciae/simple-pytree
A dead simple Python package for creating custom JAX pytree objects
- Host: GitHub
- URL: https://github.com/cgarciae/simple-pytree
- Owner: cgarciae
- License: mit
- Created: 2023-02-19T16:31:01.000Z (over 1 year ago)
- Default Branch: main
- Last Pushed: 2024-06-28T17:53:49.000Z (5 months ago)
- Last Synced: 2024-06-28T19:08:09.082Z (5 months ago)
- Topics: jax, python
- Language: Python
- Homepage:
- Size: 108 KB
- Stars: 61
- Watchers: 4
- Forks: 2
- Open Issues: 0
Metadata Files:
- Readme: README.md
- Funding: .github/FUNDING.yml
- License: LICENSE
[![codecov](https://codecov.io/gh/cgarciae/simple-pytree/branch/main/graph/badge.svg?token=3IKEUAU3C8)](https://codecov.io/gh/cgarciae/simple-pytree)
# Simple Pytree
A _dead simple_ Python package for creating custom JAX pytree objects.
* Strives to be minimal: the implementation is only ~100 lines of code
* Has no dependencies other than JAX
* It's compatible with both `dataclasses` and regular classes
* It has no intention of supporting neural network use cases (e.g. partitioning)

## Installation

```bash
pip install simple-pytree
```

## Usage

```python
import jax
from simple_pytree import Pytree

class Foo(Pytree):
    def __init__(self, x, y):
        self.x = x
        self.y = y

foo = Foo(1, 2)
foo = jax.tree_util.tree_map(lambda x: -x, foo)

assert foo.x == -1 and foo.y == -2
```

### Static fields

You can mark fields as static by assigning `static_field()` to a class attribute with the same name
as the instance attribute:

```python
import jax
from simple_pytree import Pytree, static_field

class Foo(Pytree):
    y = static_field()

    def __init__(self, x, y):
        self.x = x
        self.y = y

foo = Foo(1, 2)
foo = jax.tree_util.tree_map(lambda x: -x, foo)  # y is not modified

assert foo.x == -1 and foo.y == 2
```

Static fields are not included in the pytree leaves; instead, they
are passed as pytree metadata.

### Dataclasses

`simple_pytree` provides a `dataclass` decorator you can use with classes
that contain `static_field`s:

```python
import jax
from simple_pytree import Pytree, dataclass, static_field

@dataclass
class Foo(Pytree):
    x: int
    y: int = static_field(default=2)

foo = Foo(1)
foo = jax.tree_util.tree_map(lambda x: -x, foo)  # y is not modified

assert foo.x == -1 and foo.y == 2
```
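The dataclass example relies on `static_field` values being kept out of the leaves. As a rough illustration of that mechanism using only the standard `dataclasses` module and `jax.tree_util` (this is not `simple_pytree`'s actual code), a static field can be tagged through `dataclasses.field(metadata=...)` and routed into the treedef's auxiliary data when flattening. `my_static_field`, `register_dataclass_pytree`, the `"static"` metadata key, and `Point` are hypothetical names chosen for this sketch.

```python
import dataclasses

import jax


def my_static_field(**kwargs):
    # Hypothetical helper: an ordinary dataclass field tagged as static.
    return dataclasses.field(metadata={"static": True}, **kwargs)


def register_dataclass_pytree(cls):
    # Hypothetical decorator: split fields into leaves (dynamic) and aux data (static).
    fields = dataclasses.fields(cls)
    dynamic = [f.name for f in fields if not f.metadata.get("static", False)]
    static = [f.name for f in fields if f.metadata.get("static", False)]

    def flatten(obj):
        children = tuple(getattr(obj, name) for name in dynamic)
        aux = tuple(getattr(obj, name) for name in static)  # becomes treedef metadata
        return children, aux

    def unflatten(aux, children):
        kwargs = dict(zip(dynamic, children))
        kwargs.update(zip(static, aux))
        return cls(**kwargs)

    jax.tree_util.register_pytree_node(cls, flatten, unflatten)
    return cls


@register_dataclass_pytree
@dataclasses.dataclass(frozen=True)
class Point:
    x: int
    y: int = my_static_field(default=2)


# Only `x` is a leaf, so tree_map negates it and leaves `y` alone.
p = jax.tree_util.tree_map(lambda v: -v, Point(1))
```

Keeping static values in the auxiliary data means they are part of the tree structure rather than the leaves, which is why `tree_map` never sees them.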
`simple_pytree.dataclass` is just a wrapper around `dataclasses.dataclass`, but
when it is used, static analysis tools and IDEs understand that `static_field` is a
field specifier, just like `dataclasses.field`.

### Mutability

`Pytree` objects are immutable by default after `__init__`:

```python
from simple_pytree import Pytree, static_field

class Foo(Pytree):
    y = static_field()

    def __init__(self, x, y):
        self.x = x
        self.y = y

foo = Foo(1, 2)
foo.x = 3  # AttributeError
```
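One way to get "immutable after `__init__`" behaviour in plain Python is a metaclass that flags the instance while `__init__` runs, plus a `__setattr__` that refuses writes once the flag is gone. The sketch below assumes that approach; `_FrozenAfterInitMeta`, `FrozenAfterInit`, and the `_initializing` flag are hypothetical names, not `simple_pytree`'s API.

```python
class _FrozenAfterInitMeta(type):
    """Metaclass that marks an instance as initializing while __init__ runs."""

    def __call__(cls, *args, **kwargs):
        obj = cls.__new__(cls)
        # Write the flag straight into __dict__ so __setattr__ is not involved.
        obj.__dict__["_initializing"] = True
        try:
            obj.__init__(*args, **kwargs)
        finally:
            # Once __init__ returns (or raises), the object becomes read-only.
            del obj.__dict__["_initializing"]
        return obj


class FrozenAfterInit(metaclass=_FrozenAfterInitMeta):
    def __setattr__(self, name, value):
        if "_initializing" not in self.__dict__:
            raise AttributeError(f"{type(self).__name__} is immutable; cannot set {name!r}")
        object.__setattr__(self, name, value)


class Foo(FrozenAfterInit):
    def __init__(self, x, y):
        self.x = x  # allowed: __init__ is still running
        self.y = y


foo = Foo(1, 2)
caught = None
try:
    foo.x = 3
except AttributeError as err:
    caught = err
print(caught)  # Foo is immutable; cannot set 'x'
```

Only attribute assignment is blocked here; a fuller version would presumably also guard `__delattr__`.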
If you want to make them mutable, you can use the `mutable` argument in the class definition:

```python
from simple_pytree import Pytree, static_field

class Foo(Pytree, mutable=True):
    y = static_field()

    def __init__(self, x, y):
        self.x = x
        self.y = y

foo = Foo(1, 2)
foo.x = 3  # OK
```

### Replacing fields

If you want to make a copy of a `Pytree` object with some fields modified, you can use the `.replace()` method:
```python
from simple_pytree import Pytree, static_field

class Foo(Pytree):
    y = static_field()

    def __init__(self, x, y):
        self.x = x
        self.y = y

foo = Foo(1, 2)
foo = foo.replace(x=10)

assert foo.x == 10 and foo.y == 2
```

`replace` works for both mutable and immutable `Pytree` objects. If the class
is a `dataclass`, `replace` internally uses `dataclasses.replace`.
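A minimal sketch of how such a `replace` could behave, assuming the non-dataclass path shallow-copies the object and writes the new values straight into the copy's `__dict__`, so a read-only `__setattr__` is never triggered. Here `replace` is a free function and `Frozen` and `Point` are hypothetical classes; this is an illustration, not `simple_pytree`'s own method.

```python
import copy
import dataclasses


def replace(obj, **updates):
    """Hypothetical helper: return a copy of `obj` with some attributes changed."""
    if dataclasses.is_dataclass(obj):
        # Dataclasses already provide this operation (it calls __init__ again).
        return dataclasses.replace(obj, **updates)
    new = copy.copy(obj)  # shallow copy, so `obj` itself is never modified
    for name, value in updates.items():
        if name not in vars(new):
            raise AttributeError(f"{type(obj).__name__} has no field {name!r}")
        # Write into __dict__ directly so a read-only __setattr__ is bypassed.
        vars(new)[name] = value
    return new


class Frozen:
    """Hypothetical immutable class: normal attribute assignment always fails."""

    def __init__(self, x, y):
        object.__setattr__(self, "x", x)
        object.__setattr__(self, "y", y)

    def __setattr__(self, name, value):
        raise AttributeError(f"cannot set {name!r}")


@dataclasses.dataclass(frozen=True)
class Point:
    x: int
    y: int


a = Frozen(1, 2)
b = replace(a, x=10)  # works even though Frozen.__setattr__ always raises
p = replace(Point(1, 2), x=10)  # delegates to dataclasses.replace
```

Because the original object is copied rather than mutated, the same helper serves mutable and immutable classes alike.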