Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/jejjohnson/jaxsw
Simple differentiable approximate ocean models built with JAX.
differentiable-physics jax oceanography pde quasigeostrophy shallow-water
- Host: GitHub
- URL: https://github.com/jejjohnson/jaxsw
- Owner: jejjohnson
- License: mit
- Created: 2022-05-31T13:51:49.000Z (over 2 years ago)
- Default Branch: main
- Last Pushed: 2023-10-13T14:06:06.000Z (about 1 year ago)
- Last Synced: 2024-07-06T01:27:11.493Z (4 months ago)
- Topics: differentiable-physics, jax, oceanography, pde, quasigeostrophy, shallow-water
- Language: Jupyter Notebook
- Homepage: https://jejjohnson.github.io/jaxsw/
- Size: 107 MB
- Stars: 11
- Watchers: 4
- Forks: 1
- Open Issues: 32
Metadata Files:
- Readme: README.md
- License: LICENSE
# Simple Ocean Models in JAX
## Motivation
Sea surface height is a gateway variable to other important ocean properties, e.g. sea surface temperature, geostrophic currents.
There are many large models that attempt to simulate this, e.g. NEMO, MOM6, MITgcm.
However, they are very expensive and quite difficult to run. So there are many small models that serve as useful approximations, e.g. Quasi-Geostrophic and Shallow Water.
This repo attempts to showcase how we can use some modern tools to construct dynamical systems for PDEs.

What makes this different from the many existing implementations is that we will be using JAX.
JAX is basically NumPy on steroids: the API is very similar, but we also get modern tooling (automatic differentiation, vectorization, JIT compilation) along with speed.
Most importantly, JAX is differentiable.
Having a differentiable model is important because it allows us to:

* Learn some of the hyperparameters if necessary
* Embed this in machine learning models where differentiability is needed

**Why Not PyTorch?**

We could easily just use PyTorch. However, there are some advantages to JAX over other frameworks like PyTorch and TensorFlow:

* Familiar NumPy-like API, which is nice for newcomers in the scientific community
* CPU/GPU/TPU capabilities with minimal code changes
* Gradient operators (function transformations) instead of storing the transformations in the tensors
* Functional-like language, which is easier to read for newcomers
* Auto-vectorization, so we can easily parallelize the operators for multiple dimensions without code changes (note: TensorFlow has this)
* JIT compilation speeds up the code by a lot (note: both PyTorch and TensorFlow have this)
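
As a minimal sketch of three of the features listed above (gradient operators, auto-vectorization and JIT compilation), the snippet below applies `jax.grad`, `jax.vmap` and `jax.jit` to a toy function. The `kinetic_energy` diagnostic is a hypothetical example chosen for illustration; it is not part of this package.

```python
import jax
import jax.numpy as jnp


def kinetic_energy(u):
    """Mean kinetic energy of a 1D velocity field (toy diagnostic)."""
    return 0.5 * jnp.mean(u**2)


u = jnp.linspace(-1.0, 1.0, 5)

# Gradient operator: jax.grad returns a new function that computes
# d(energy)/du. Analytically this is u / u.size.
grad_energy = jax.grad(kinetic_energy)
g = grad_energy(u)

# Auto-vectorization: apply the same function over a batch of fields
# without rewriting it.
batch = jnp.stack([u, 2.0 * u, 3.0 * u])
energies = jax.vmap(kinetic_energy)(batch)

# JIT compilation: the function is traced once and then runs compiled.
fast_energy = jax.jit(kinetic_energy)
e_fast = fast_energy(u)
```

Note that the gradient is obtained by transforming the function itself, so nothing has to be attached to the array `u`.
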
---

## Applications

This library will be relatively general, but it will serve as a development platform for the following applications:

* Generate Simulations
* Surrogate Models
* Data Assimilation

---

## Main Components

Without making it too complicated, we settled on a few key objects that make up the package.

**Domain**
This will be the object to define the grids where all of the fields live. It will be easy to access the coordinates, boundaries, grids and cell volumes. We don't need to store the grid all of the time, instead we just generate it as we see fit.
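
A minimal sketch of what such an object could look like, assuming a uniform rectangular grid. The class name `Domain` and its fields (`xmin`, `xmax`, `dx`) are hypothetical and chosen for illustration; they are not necessarily the package's actual API. Only the bounds and spacing are stored, and the grid is generated on demand.

```python
from typing import NamedTuple

import jax.numpy as jnp


class Domain(NamedTuple):
    """Hypothetical domain: stores only bounds and spacing per axis."""

    xmin: tuple
    xmax: tuple
    dx: tuple

    @property
    def shape(self):
        # Number of grid points per axis, end points included.
        return tuple(
            int(round((hi - lo) / d)) + 1
            for lo, hi, d in zip(self.xmin, self.xmax, self.dx)
        )

    @property
    def coords(self):
        # One 1D coordinate vector per axis, built when requested.
        return tuple(
            jnp.linspace(lo, hi, n)
            for lo, hi, n in zip(self.xmin, self.xmax, self.shape)
        )

    @property
    def grid(self):
        # Full grid of shape (*shape, ndim), generated only when needed.
        return jnp.stack(jnp.meshgrid(*self.coords, indexing="ij"), axis=-1)

    @property
    def cell_volume(self):
        # Product of the grid spacings.
        volume = 1.0
        for d in self.dx:
            volume *= d
        return volume


domain = Domain(xmin=(0.0, 0.0), xmax=(1.0, 2.0), dx=(0.1, 0.2))
grid = domain.grid
```
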
**Operators**
This will be a suite of functions for different gradient calculations and combined operations for well-known equations. We will primarily focus on finite difference operators with the `finiteDiffX` package. At a later date, we can introduce spectral and finite volume methods.
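
To illustrate the kind of operator meant here, the sketch below writes a central difference and a five-point Laplacian directly in `jax.numpy`. It is a hand-written stand-in and does not use the `finiteDiffX` API; the function names and the periodic boundary treatment via `jnp.roll` are assumptions made for this example.

```python
import jax.numpy as jnp


def ddx_central(u, dx, axis=0):
    """Second-order central difference on a periodic grid (assumption)."""
    return (jnp.roll(u, -1, axis=axis) - jnp.roll(u, 1, axis=axis)) / (2.0 * dx)


def laplacian(u, dx, dy):
    """Five-point Laplacian on a periodic 2D grid (assumption)."""
    d2x = (jnp.roll(u, -1, axis=0) - 2.0 * u + jnp.roll(u, 1, axis=0)) / dx**2
    d2y = (jnp.roll(u, -1, axis=1) - 2.0 * u + jnp.roll(u, 1, axis=1)) / dy**2
    return d2x + d2y


n = 64
dx = 2.0 * jnp.pi / n
x = jnp.arange(n) * dx

# 1D check: d/dx sin(x) should be close to cos(x).
dudx = ddx_central(jnp.sin(x), dx)

# 2D check: the Laplacian of sin(x), constant in y, should be close to -sin(x).
field = jnp.sin(x)[:, None] * jnp.ones((1, n))
lap = laplacian(field, dx, dx)
```
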
**Integrators**
We will use the `diffrax` package to do the time integration. We'll use the method-of-lines technique to formulate all of our PDEs to calculate the RHS of the equation for the state at $t$. Then we can propagate them through the time integrator to get the state at $t+1$.
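
The method-of-lines idea can be sketched as follows for 1D linear advection, $\partial_t u + c\,\partial_x u = 0$. To keep the example dependent on JAX alone, a hand-written RK4 stepper with `jax.lax.scan` stands in for a `diffrax` solver; the RHS uses the `(t, y, args)` signature that `diffrax.ODETerm` expects, so the same function could be handed to `diffrax`. The function names, the periodic boundaries and all parameter values are assumptions for illustration.

```python
import jax
import jax.numpy as jnp


def rhs(t, u, args):
    """Spatially discretized RHS: du/dt = -c * du/dx (periodic, central)."""
    c, dx = args
    dudx = (jnp.roll(u, -1) - jnp.roll(u, 1)) / (2.0 * dx)
    return -c * dudx


def rk4_step(u, t, dt, args):
    """One classical Runge-Kutta 4 step (stand-in for a diffrax solver)."""
    k1 = rhs(t, u, args)
    k2 = rhs(t + 0.5 * dt, u + 0.5 * dt * k1, args)
    k3 = rhs(t + 0.5 * dt, u + 0.5 * dt * k2, args)
    k4 = rhs(t + dt, u + dt * k3, args)
    return u + (dt / 6.0) * (k1 + 2.0 * k2 + 2.0 * k3 + k4)


def integrate(u0, t0, dt, n_steps, args):
    """Propagate the state through n_steps time steps."""
    ts = t0 + dt * jnp.arange(n_steps)

    def body(u, t):
        return rk4_step(u, t, dt, args), None

    u_final, _ = jax.lax.scan(body, u0, ts)
    return u_final


n = 128
dx = 1.0 / n
x = jnp.arange(n) * dx
u0 = jnp.sin(2.0 * jnp.pi * x)

c = 1.0
dt = 0.5 * dx      # CFL number of 0.5
n_steps = 2 * n    # integrates to t = 1, one full period of the wave
u1 = integrate(u0, 0.0, dt, n_steps, (c, dx))
```

After one period the wave should return close to its initial profile, up to the small phase error of the central difference.
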
**Params, State & Equations of Motion**
We will have a general API for how we store parameters, initialize states and pass them both through the equation of motion. To handle what's differentiable and what is not, we will use the `equinox` package.
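
A minimal sketch of this separation, using plain `NamedTuple` pytrees as a stand-in for `equinox` modules so that the example depends on JAX alone. The names `Params`, `State`, `equation_of_motion`, `step` and `loss`, and the toy diffusion problem, are hypothetical. Differentiating with respect to the first argument only is what marks the parameters as learnable and the state as fixed.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Params(NamedTuple):
    nu: jnp.ndarray  # diffusivity, treated as learnable


class State(NamedTuple):
    u: jnp.ndarray  # prognostic field


def equation_of_motion(t, state, params, dx):
    """RHS of du/dt = nu * d2u/dx2 on a periodic grid (assumption)."""
    d2u = (jnp.roll(state.u, -1) - 2.0 * state.u + jnp.roll(state.u, 1)) / dx**2
    return State(u=params.nu * d2u)


def step(state, params, dx, dt):
    """One forward-Euler step, enough for the illustration."""
    tendency = equation_of_motion(0.0, state, params, dx)
    return State(u=state.u + dt * tendency.u)


def loss(params, state, target, dx, dt):
    """Mismatch between the predicted next state and a target state."""
    pred = step(state, params, dx, dt)
    return jnp.mean((pred.u - target.u) ** 2)


n = 32
dx = 2.0 * jnp.pi / n
dt = 0.01
x = jnp.arange(n) * dx
state = State(u=jnp.sin(x))

true_params = Params(nu=jnp.asarray(0.1))
target = step(state, true_params, dx, dt)

# jax.grad differentiates w.r.t. the first argument (the params) only.
guess = Params(nu=jnp.asarray(0.05))
grads = jax.grad(loss)(guess, state, target, dx, dt)
grads_at_truth = jax.grad(loss)(true_params, state, target, dx, dt)
```

The gradient comes back with the same structure as `Params`, so it can be passed directly to an optimizer update.
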
**Configs**
We will use the `hydra` package to keep track of the configurations and to initialize parameters for experiments.
---
## Installation

### pip

We can install it directly via pip from the GitHub repository:

```bash
pip install "git+https://github.com/jejjohnson/jaxsw.git"
```

### Cloning

We can also clone the git repository:

```bash
git clone https://github.com/jejjohnson/jaxsw.git
cd jaxsw
```

#### poetry

The easiest way to get started is to use the poetry package, which also installs all necessary dev packages:

```bash
poetry install
```

#### pip

We can also install via `pip`:

```bash
pip install .
```

### Conda

We also have a conda environment with all of the equivalent dependencies.

```bash
conda env create -f environments/jax_linux_cpu.yaml
conda activate jaxsw
```

---

## Contributions

---

## Acknowledgements

* [`qg_utils`](https://github.com/bderembl/qgutils) - useful functions for dealing with QG equations
* [`jaxdf`](https://github.com/ucl-bug/jaxdf) - Nice API for defining operators for PDEs.
* [`jax-cfd`](https://github.com/google/jax-cfd) - Nice API for defining PDEs
* [`invobs-data-assimilation`](https://github.com/googleinterns/invobs-data-assimilation) - Nice API for Dynamical Systems
* [`MASSH`](https://github.com/leguillf/MASSH) - The differentiable QG and SW models applied to sea surface height interpolation.
* [`qgm_pytorch`](https://github.com/louity/qgm_pytorch) - Quasi-Geostrophic Model in PyTorch
* [`QGNet`](https://github.com/redouanelg/qgsw-DI/blob/master/QGNET/QG_PyTorch.ipynb) - QG implementation in PyTorch with convolutions.