{"id":18423681,"url":"https://github.com/jejjohnson/jaxsw","last_synced_at":"2025-04-07T15:32:55.746Z","repository":{"id":38085089,"uuid":"498356793","full_name":"jejjohnson/jaxsw","owner":"jejjohnson","description":"Simple differentiable approximate ocean models built with JAX.","archived":false,"fork":false,"pushed_at":"2023-10-13T14:06:06.000Z","size":112240,"stargazers_count":13,"open_issues_count":32,"forks_count":1,"subscribers_count":4,"default_branch":"main","last_synced_at":"2025-03-22T20:33:34.606Z","etag":null,"topics":["differentiable-physics","jax","oceanography","pde","quasigeostrophy","shallow-water"],"latest_commit_sha":null,"homepage":"https://jejjohnson.github.io/jaxsw/","language":"Jupyter Notebook","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"mit","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/jejjohnson.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null}},"created_at":"2022-05-31T13:51:49.000Z","updated_at":"2024-12-10T00:30:23.000Z","dependencies_parsed_at":"2024-01-13T07:30:20.739Z","dependency_job_id":null,"html_url":"https://github.com/jejjohnson/jaxsw","commit_stats":{"total_commits":38,"total_committers":3,"mean_commits":"12.666666666666666","dds":0.07894736842105265,"last_synced_commit":"c8fe60860db6042a75120160b7d84c7db6281417"},"previous_names":[],"tags_count":0,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/jejjohnson%2Fjaxsw","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/jejjohnson%2Fjaxsw/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/jejjohnson%2Fjaxsw/releases","manifests_url":"
https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/jejjohnson%2Fjaxsw/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/jejjohnson","download_url":"https://codeload.github.com/jejjohnson/jaxsw/tar.gz/refs/heads/main","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":247679915,"owners_count":20978161,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["differentiable-physics","jax","oceanography","pde","quasigeostrophy","shallow-water"],"created_at":"2024-11-06T04:38:20.859Z","updated_at":"2025-04-07T15:32:50.711Z","avatar_url":"https://github.com/jejjohnson.png","language":"Jupyter Notebook","funding_links":[],"categories":[],"sub_categories":[],"readme":"# Simple Ocean Models in JAX\n\n## Motivation\n\nSea surface height is a gateway variable to other important ocean properties, e.g., sea surface temperature and geostrophic currents.\nThere are many large, comprehensive models that attempt to model this, e.g., NEMO, MOM6, and MITgcm.\nHowever, they are computationally expensive and quite difficult to run, so there are many smaller models that serve as useful approximations, e.g., the quasi-geostrophic (QG) and shallow-water (SW) models.\nThis repo aims to showcase how we can use some modern tools to construct dynamical systems for PDEs.\n\nWhat sets this apart from the many existing implementations is that we\nwill be using JAX. 
\nJAX is basically NumPy on steroids: the API is very similar, but we also get modern tooling along with speed.\nMost importantly, JAX is differentiable.\nHaving a differentiable model is important because it allows us to:\n\n* Learn some of the hyperparameters if necessary\n* Embed the model in machine learning pipelines where differentiability is needed\n\n**Why Not PyTorch?**\n\nWe could easily just use PyTorch. However, JAX has some advantages over other frameworks like PyTorch and TensorFlow:\n\n* Familiar NumPy-like API, which is nice for newcomers in the scientific community\n* CPU/GPU/TPU capabilities with minimal code changes\n* Gradient operators as function transformations (e.g., `jax.grad`) instead of recording the computation history in the tensors\n* Functional programming style, which is easier to read for newcomers\n* Automatic vectorization (`jax.vmap`), so we can easily parallelize the operators over multiple dimensions without code changes (note: TensorFlow also has this)\n* JIT compilation can speed up the code considerably (note: both PyTorch and TensorFlow have this)\n\n---\n## Applications\n\nThis library will be relatively general, but it will serve as a development platform for the following applications:\n\n* Generating Simulations\n* Surrogate Models\n* Data Assimilation\n\n---\n## Main Components\n\nTo keep things simple, we settled on a few key objects that make up the package.\n\n**Domain**\n\nThis will be the object that defines the grids on which all of the fields live. It will provide easy access to the coordinates, boundaries, grids, and cell volumes. We don't need to store the grid all of the time; instead, we generate it as needed.\n\n**Operators**\n\nThis will be a suite of functions for different gradient calculations and for the combined operators that appear in well-known equations. We will primarily focus on finite-difference operators from the `finiteDiffX` package. 
At a later date, we can introduce spectral and finite-volume methods.\n\n**Integrators**\n\nWe will use the `diffrax` package for time integration. Following the method of lines, we discretize each PDE in space so that it becomes a system of ODEs whose right-hand side (RHS) gives the time tendency of the state at time $t$. We then pass this RHS to the time integrator to get the state at the next time step, $t + \\Delta t$.\n\n**Params, State \u0026 Equations of Motion**\n\nWe will have a general API for storing parameters, initializing states, and passing both through the equations of motion. To separate what is differentiable from what is not, we will use the `equinox` package.\n\n**Configs**\n\nWe will use the `hydra` package to keep track of the configurations and to initialize parameters for experiments.\n\n---\n## Installation\n\n### pip\n\nWe can install it directly via pip from the GitHub repository:\n\n```bash\npip install \"git+https://github.com/jejjohnson/jaxsw.git\"\n```\n\n### Cloning\n\nWe can also clone the git repository:\n\n```bash\ngit clone https://github.com/jejjohnson/jaxsw.git\ncd jaxsw\n```\n\n#### poetry\n\nThe easiest way to get started is to use `poetry`, which also installs all of the necessary dev packages:\n\n```bash\npoetry install\n```\n\n#### pip\n\nWe can also install it with `pip`:\n\n```bash\npip install .\n```\n\n### Conda\n\nWe also have a conda environment with all of the equivalent dependencies:\n\n```bash\nconda env create -f environments/jax_linux_cpu.yaml\nconda activate jaxsw\n```\n\n---\n## Contributions\n\n\n\n---\n## Acknowledgements\n\n* [`qg_utils`](https://github.com/bderembl/qgutils) - Useful functions for dealing with QG equations.\n* [`jaxdf`](https://github.com/ucl-bug/jaxdf) - Nice API for defining operators for PDEs.\n* [`jax-cfd`](https://github.com/google/jax-cfd) - Nice API for defining PDEs.\n* [`invobs-data-assimilation`](https://github.com/googleinterns/invobs-data-assimilation) - Nice API for dynamical systems.\n* [`MASSH`](https://github.com/leguillf/MASSH) - The 
differentiable QG and SW models applied to sea surface height interpolation.\n* [`qgm_pytorch`](https://github.com/louity/qgm_pytorch) - Quasi-geostrophic model in PyTorch.\n* [`QGNet`](https://github.com/redouanelg/qgsw-DI/blob/master/QGNET/QG_PyTorch.ipynb) - QG implementation in PyTorch with convolutions.","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fjejjohnson%2Fjaxsw","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fjejjohnson%2Fjaxsw","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fjejjohnson%2Fjaxsw/lists"}