# Dataclass Array

[![Unittests](https://github.com/google-research/dataclass_array/actions/workflows/pytest_and_autopublish.yml/badge.svg)](https://github.com/google-research/dataclass_array/actions/workflows/pytest_and_autopublish.yml)
[![PyPI version](https://badge.fury.io/py/dataclass_array.svg)](https://badge.fury.io/py/dataclass_array)
[![Documentation Status](https://readthedocs.org/projects/dataclass-array/badge/?version=latest)](https://dataclass-array.readthedocs.io/en/latest/?badge=latest)

`DataclassArray`s are dataclasses that behave like numpy arrays (they can be
batched, reshaped, sliced, ...). They are compatible with Jax, TensorFlow, and
numpy (with torch support planned).

This reduces boilerplate and improves readability. See the
[motivating examples](#motivating-examples) section below.

For an example of dataclass arrays used in practice, see
[visu3d](https://github.com/google-research/visu3d).

## Documentation

### Definition

To create a `dca.DataclassArray`, take a frozen dataclass and:

*   Inherit from `dca.DataclassArray`.
*   Annotate the fields with `dataclass_array.typing` to specify the inner shape
    and dtype of the array (see below for static or nested dataclass fields).
    The array types are aliases of
    [`etils.array_types`](https://github.com/google/etils/blob/main/etils/array_types/README.md).

```python
import dataclass_array as dca
from dataclass_array.typing import FloatArray


class Ray(dca.DataclassArray):
  pos: FloatArray['*batch_shape 3']
  dir: FloatArray['*batch_shape 3']
```

### Usage

Afterwards, the dataclass can be used as a numpy array:

```python
import jax
import jax.numpy as jnp

ray = Ray(pos=jnp.zeros((3, 3)), dir=jnp.eye(3))

ray.shape == (3,)  # 3 rays batched together
ray.pos.shape == (3, 3)  # Individual fields are still available

# Numpy slicing/indexing/masking
ray[..., 1:2]  # Slice the batch: shape (1,)
ray[jnp.linalg.norm(ray.dir, axis=-1) > 1e-7]  # Boolean mask over the batch

# Shape transformation
ray = ray.reshape((1, 3))
ray = ray.reshape('h w -> w h')  # Native einops support
ray = ray.flatten()

# Stack multiple dataclass arrays together
rays = dca.stack([ray, ray])

# Supports TF, Jax, Numpy (torch planned) and can be easily converted
ray = ray.as_jax()  # as_np(), as_tf()
ray.xnp == jax.numpy  # `numpy`, `jax.numpy`, `tf.experimental.numpy`

# Compatible with `jax.tree_util`, `jax.vmap`, ...
ray = jax.tree_util.tree_map(lambda x: x + 1, ray)
```

A `DataclassArray` has 2 types of fields:

*   Array fields: Fields batched like numpy arrays, with reshape, slicing, ...
    Can be `xnp.ndarray` or nested `dca.DataclassArray`.
*   Static fields: Other non-numpy fields. They are not modified by reshaping,
    slicing, ... Static fields are also ignored by `jax.tree.map`.

```python
from typing import Any

import numpy as np


class MyArray(dca.DataclassArray):
  # Array fields
  a: FloatArray['*batch_shape 3']  # Defined by `etils.array_types`
  b: FloatArray['*batch_shape _ _']  # Dynamic shape
  c: Ray  # Nested DataclassArray (equivalent to `Ray['*batch_shape']`)
  d: Ray['*batch_shape 6']

  # Array fields explicitly defined
  e: Any = dca.field(shape=(3,), dtype=np.float32)
  f: Any = dca.field(shape=(None, None), dtype=np.float32)  # Dynamic shape
  g: Ray = dca.field(shape=(3,), dtype=Ray)  # Nested DataclassArray

  # Static fields (everything not defined as above)
  static0: float
  static1: np.ndarray
```

### Vectorization

`@dca.vectorize_method` allows your dataclass methods to automatically support
batching:

1.  Implement the method as if `self.shape == ()`.
2.  Decorate the method with `dca.vectorize_method`.

```python
class Camera(dca.DataclassArray):
  K: FloatArray['*batch_shape 4 4']
  resolution: tuple[int, int]

  @dca.vectorize_method
  def rays(self) -> Ray:
    # Inside `@dca.vectorize_method`, the shape is always guaranteed to be `()`
    assert self.shape == ()
    assert self.K.shape == (4, 4)

    # Compute the rays as if there was only a single camera
    return Ray(pos=..., dir=...)
```

Afterwards, we can generate rays for multiple cameras batched together:

```python
cams = Camera(K=K, resolution=(h, w))  # K.shape == (num_cams, 4, 4)
rays = cams.rays()  # Generate the rays for all the cameras

cams.shape == (num_cams,)
rays.shape == (num_cams, h, w)
```

`@dca.vectorize_method` is similar to `jax.vmap` but:

*   It only works on `dca.DataclassArray` methods.
*   Instead of vectorizing a single axis, `@dca.vectorize_method` vectorizes
    over `*self.shape` (not just `self.shape[0]`). This is as if `vmap` were
    applied to `self.flatten()`.
*   When there are multiple arguments, axes with dimension `1` are broadcast.

For example, with `__matmul__(self, x: T) -> T`:

```python
() @ (*x,) -> (*x,)
(b,) @ (b, *x) -> (b, *x)
(b,) @ (1, *x) -> (b, *x)
(1,) @ (b, *x) -> (b, *x)
(b, h, w) @ (b, h, w, *x) -> (b, h, w, *x)
(1, h, w) @ (b, 1, 1, *x) -> (b, h, w, *x)
(a, *x) @ (b, *x) -> Error: Incompatible a != b
```

To try it on Colab, see the `visu3d` dataclass
[Colab tutorial](https://colab.research.google.com/github/google-research/visu3d/blob/main/docs/dataclass.ipynb).

## Motivating examples

`dca.DataclassArray` improves readability by simplifying common patterns:

*   Reshaping all fields of a dataclass:

    Before (`rays` is a simple `dataclass`):

    ```python
    num_rays = math.prod(rays.origins.shape[:-1])
    rays = jax.tree.map(lambda r: r.reshape((num_rays, -1)), rays)
    ```

    After (`rays` is a `DataclassArray`):

    ```python
    rays = rays.flatten()  # (b, h, w) -> (b*h*w,)
    ```

*   Rendering a video:

    Before (`cams: list[Camera]`):

    ```python
    img = cams[0].render(scene)
    imgs = np.stack([cam.render(scene) for cam in cams[::2]])
    imgs = np.stack([cam.render(scene) for cam in cams])
    ```

    After (`cams: Camera` with `cams.shape == (num_cams,)`):

    ```python
    img = cams[0].render(scene)  # Render only the first camera (to debug)
    imgs = cams[::2].render(scene)  # Render every other frame (for quicker iteration)
    imgs = cams.render(scene)  # Render all cameras at once
    ```

## Installation

```sh
pip install dataclass_array
```

*This is not an official Google product*
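You can model the Usage section and the split between array and static fields with nothing more than `dataclasses` and numpy. The sketch below is a hypothetical toy, not how `dca.DataclassArray` is implemented. `ToyRay`, its `INNER_NDIM` table and its methods are names made up for illustration.

The toy assumes each array field has a known number of inner dimensions. It reads the batch shape off a field by dropping those inner dimensions. Indexing and `reshape` touch only the batch dimensions of every array field, and the static field is copied unchanged.

```python
import dataclasses

import numpy as np


@dataclasses.dataclass(frozen=True)
class ToyRay:
  """Hypothetical stand-in for `Ray`: two array fields and one static field."""

  pos: np.ndarray  # shape (*batch_shape, 3)
  dir: np.ndarray  # shape (*batch_shape, 3)
  name: str = 'rays'  # Static field: never reshaped or sliced

  # Assumed number of inner (non-batch) dims per array field. This is a plain
  # class attribute, not a dataclass field, because it has no annotation.
  INNER_NDIM = {'pos': 1, 'dir': 1}

  @property
  def shape(self):
    # Batch shape = shape of a field with its inner dims removed.
    return self.pos.shape[:self.pos.ndim - self.INNER_NDIM['pos']]

  def _map(self, fn):
    # Apply `fn(array, inner_ndim)` to every array field; static fields are
    # copied unchanged by `dataclasses.replace`.
    updates = {k: fn(getattr(self, k), n) for k, n in self.INNER_NDIM.items()}
    return dataclasses.replace(self, **updates)

  def __getitem__(self, index):
    # Index the batch dims only; append full slices for the inner dims.
    if not isinstance(index, tuple):
      index = (index,)
    return self._map(lambda x, n: x[index + (slice(None),) * n])

  def reshape(self, batch_shape):
    # Reshape the batch dims, then re-append each field's inner dims.
    return self._map(
        lambda x, n: x.reshape(tuple(batch_shape) + x.shape[x.ndim - n:]))

  def flatten(self):
    return self.reshape((-1,))


ray = ToyRay(pos=np.zeros((3, 3)), dir=np.eye(3))
print(ray.shape)  # (3,)
print(ray[..., 1:2].shape)  # (1,)
print(ray[np.linalg.norm(ray.dir, axis=-1) > 1e-7].shape)  # (3,)
print(ray.reshape((1, 3)).pos.shape)  # (1, 3, 3)
print(ray.reshape((1, 3)).flatten().shape)  # (3,)
```

Keeping a per-field count of inner dimensions is what lets one batch-level operation such as `flatten()` be applied uniformly. Writing a per-field `jax.tree.map` lambda by hand, as in the "Before" motivating example, is the alternative.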
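The `@dca.vectorize_method` contract is: write the method for `self.shape == ()`, and it gets batched over all of `*self.shape`. You can model this with a plain Python loop that:

1.  flattens the batch,
2.  calls the method on each scalar element,
3.  stacks the results,
4.  restores the batch shape.

`naive_vectorize` and `ToyCamera` are hypothetical. The loop handles a single hard-coded array field `K`, no extra arguments and no broadcasting. It says nothing about how the real decorator achieves `jax.vmap`-like performance.

```python
import dataclasses
import functools

import numpy as np


def naive_vectorize(method):
  """Hypothetical decorator: call `method` once per batch element of `self`."""

  @functools.wraps(method)
  def wrapper(self):
    batch_shape = self.K.shape[:-2]  # K has inner shape (4, 4)
    flat_K = self.K.reshape((-1, 4, 4))
    # Each call sees a single camera, so `self.K.shape == (4, 4)` inside.
    outs = [method(dataclasses.replace(self, K=k)) for k in flat_K]
    out = np.stack(outs)
    # Put the batch dims back in front of the per-camera output shape.
    return out.reshape(batch_shape + out.shape[1:])

  return wrapper


@dataclasses.dataclass(frozen=True)
class ToyCamera:
  K: np.ndarray  # shape (*batch_shape, 4, 4)
  resolution: tuple  # Static field (h, w)

  @naive_vectorize
  def pixel_grid(self):
    # Written as if there were a single camera.
    assert self.K.shape == (4, 4)
    h, w = self.resolution
    # Toy per-camera output: an (h, w) image filled with the entry K[0, 0].
    return np.full((h, w), self.K[0, 0])


K = np.stack([np.eye(4) * f for f in (1.0, 2.0, 3.0)])  # (3, 4, 4)
cams = ToyCamera(K=K, resolution=(2, 5))
print(cams.pixel_grid().shape)  # (3, 2, 5)
```

The per-camera `(h, w)` outputs come back with shape `(*cams.shape, h, w)`. This mirrors the `rays.shape == (num_cams, h, w)` example in the Vectorization section.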
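Here is one way to read the broadcasting table in the Vectorization section:

*   Align `self.shape` with the leading axes of the argument.
*   Broadcast each aligned pair numpy-style (an axis of size `1` stretches to match the other side).
*   Keep the argument's trailing inner shape `*x` unchanged.

`vectorized_output_shape` below is a hypothetical helper that encodes this reading for a single argument. It is not part of `dataclass_array`, and it assumes the output keeps the argument's inner shape, as in the `__matmul__` example.

```python
def vectorized_output_shape(self_shape, arg_shape):
  """Hypothetical helper: output shape of `self @ x` under the table's rule.

  `self_shape` is the batch shape of the dataclass array. The leading
  `len(self_shape)` axes of `arg_shape` are batch axes; the remaining axes are
  the argument's inner shape `*x`.
  """
  n = len(self_shape)
  if len(arg_shape) < n:
    raise ValueError(f'{arg_shape} has fewer dims than batch shape {self_shape}')
  arg_batch, inner = arg_shape[:n], arg_shape[n:]
  batch = []
  for a, b in zip(self_shape, arg_batch):
    # Equal sizes pass through; a size-1 axis stretches to the other size.
    if a == b or b == 1:
      batch.append(a)
    elif a == 1:
      batch.append(b)
    else:
      raise ValueError(f'Incompatible {a} != {b}')
  return tuple(batch) + tuple(inner)


# Rows of the table with b=2, h=3, w=4 and inner shape *x = (3,)
print(vectorized_output_shape((), (3,)))  # (3,)
print(vectorized_output_shape((1, 3, 4), (2, 1, 1, 3)))  # (2, 3, 4, 3)
```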