{"id":46124994,"url":"https://github.com/allen-adastra/xarray_jax","last_synced_at":"2026-03-02T01:36:38.546Z","repository":{"id":255926656,"uuid":"853904203","full_name":"allen-adastra/xarray_jax","owner":"allen-adastra","description":"Simple Xarray + JAX Integration","archived":false,"fork":false,"pushed_at":"2025-09-16T22:31:45.000Z","size":237,"stargazers_count":21,"open_issues_count":1,"forks_count":1,"subscribers_count":1,"default_branch":"main","last_synced_at":"2025-09-17T00:43:42.262Z","etag":null,"topics":["jax","xarray"],"latest_commit_sha":null,"homepage":"","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"mit","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/allen-adastra.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null,"zenodo":null,"notice":null,"maintainers":null,"copyright":null,"agents":null,"dco":null,"cla":null}},"created_at":"2024-09-07T21:25:24.000Z","updated_at":"2025-09-16T22:31:49.000Z","dependencies_parsed_at":"2024-09-07T22:35:46.311Z","dependency_job_id":"b88106d1-59f4-4825-8766-2f2bbfc7a241","html_url":"https://github.com/allen-adastra/xarray_jax","commit_stats":null,"previous_names":["allen-adastra/xarray_jax"],"tags_count":1,"template":false,"template_full_name":null,"purl":"pkg:github/allen-adastra/xarray_jax","repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/allen-adastra%2Fxarray_jax","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/allen-adastra%2Fxarray_jax/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/allen-adastra%2Fxarray_jax/releases","manifests_url":"https://r
epos.ecosyste.ms/api/v1/hosts/GitHub/repositories/allen-adastra%2Fxarray_jax/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/allen-adastra","download_url":"https://codeload.github.com/allen-adastra/xarray_jax/tar.gz/refs/heads/main","sbom_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/allen-adastra%2Fxarray_jax/sbom","scorecard":null,"host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":286080680,"owners_count":29989368,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2026-03-01T22:42:38.399Z","status":"ssl_error","status_checked_at":"2026-03-01T22:41:51.863Z","response_time":124,"last_error":"SSL_read: unexpected eof while reading","robots_txt_status":"success","robots_txt_updated_at":"2025-07-24T06:49:26.215Z","robots_txt_url":"https://github.com/robots.txt","online":false,"can_crawl_api":true,"host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["jax","xarray"],"created_at":"2026-03-02T01:36:37.792Z","updated_at":"2026-03-02T01:36:38.541Z","avatar_url":"https://github.com/allen-adastra.png","language":"Python","readme":"# Simple Xarray + JAX Integration\n\nThis is an experiment in integrating Xarray + JAX in a simple way.\n\n``` python\nimport equinox as eqx\nimport jax\nimport jax.numpy as jnp\nimport xarray as xr\nimport xarray_jax as xj\n\n# Construct a DataArray.\nda = xr.DataArray(\n    xr.Variable([\"x\", \"y\"], jnp.ones((2, 3))),\n    coords={\"x\": [1, 2], \"y\": [3, 4, 5]},\n    name=\"foo\",\n    attrs={\"attr1\": \"value1\"},\n)\n\n# Do some operations inside a JIT compiled function.\n@eqx.filter_jit\ndef some_function(data):\n    neg_data = -1.0 * data\n    return neg_data * 
neg_data.coords[\"y\"]  # Multiply data by coords.\n\nda = some_function(da)\n\n# Construct a xr.DataArray with dummy data (useful for tree manipulation).\nda_mask = jax.tree.map(lambda _: True, da)\n\n# Take the gradient of a jitted function.\n@eqx.filter_jit\ndef fn(data):\n    return (data**2.0).sum().data\n\nda_grad = jax.grad(fn)(da)\n\n# Convert to a custom XjDataArray, implemented as an equinox module.\n# (Useful for avoiding potentially weird xarray interactions with JAX).\nxj_da = xj.from_xarray(da)\n\n# Convert back to a xr.DataArray.\nda = xj.to_xarray(xj_da)\n\n# Use xj.var_change_on_unflatten to allow us to expand the dimensions of the DataArray.\ndef add_dim_to_var(var):\n    var._dims = (\"new_dim\", *var._dims)\n    return var\n\nwith xj.var_change_on_unflatten(add_dim_to_var):\n    da = jax.tree.map(lambda x: jnp.expand_dims(x, axis=0), da)\n```\n## Installation\n```bash\npip install xarray_jax\n```\n\n## Status\n- [x] PyTree node registrations\n  - [x] `xr.Variable`\n  - [x] `xr.DataArray`\n  - [x] `xr.Dataset`\n- [x] Minimal shadow types implemented as [equinox modules](https://github.com/patrick-kidger/equinox) to handle edge cases (Note: these types are merely data structures that contain the data of these types. 
They don't have any of the methods of the xarray types).\n  - [x] `XjVariable`\n  - [x] `XjDataArray`\n  - [x] `XjDataset`\n- [x] `xj.from_xarray` and `xj.to_xarray` functions to go between `xj` and `xr` types.\n- [x] Support for `xr` types with dummy data (useful for tree manipulation).\n- [x] Support for transformations that change the dimensionality of the data using the `var_change_on_unflatten` context manager.\n\n## Sharp Edges\n\n### Prefer `eqx.filter_jit` over `jax.jit`\nThere are some edge cases with metadata that `eqx.filter_jit` handles but `jax.jit` does not.\n\n### Dispatching to jnp is not supported yet\nPending resolution of https://github.com/pydata/xarray/issues/7848.\n``` python\nvar = xr.Variable(dims=(\"x\", \"y\"), data=jnp.ones((3, 3)))\n\n# This will fail.\njnp.square(var)\n\n# This will work.\nxr.apply_ufunc(jnp.square, var)\n```\n\n\n## Distinction from the GraphCast Implementation\nThis experiment is largely inspired by the [GraphCast implementation](https://github.com/google-deepmind/graphcast/blob/main/graphcast/xarray_jax.py), with a direct re-use of the `_HashableCoords` in that project.\n\nHowever, this experiment aims to:\n1. Take a more minimalist approach (and thus neglect some features, such as support for JAX arrays as coordinates).\n2. Find a solution more compatible with common JAX PyTree manipulation patterns that trigger errors with Xarray types. For example, it's common to use boolean masks to filter out elements of a PyTree, but this tends to fail with Xarray types.\n\n## Acknowledgements\nThis repo was made possible by great discussions within the JAX + Xarray open source community, especially [this one](https://github.com/pydata/xarray/discussions/8164). 
In particular, the author would like to acknowledge @shoyer, @mjwillson, and @TomNicholas.\n","funding_links":[],"categories":[],"sub_categories":[],"project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fallen-adastra%2Fxarray_jax","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fallen-adastra%2Fxarray_jax","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fallen-adastra%2Fxarray_jax/lists"}