https://github.com/rose-stl-lab/dyffusion
[NeurIPS 2023] A Dynamics-informed Diffusion Model for Spatiotemporal Forecasting
deep-learning diffusion diffusion-models ensemble-forecasts machine-learning neurips neurips-2023 probabilistic-forecasting pytorch pytorch-lightning spatiotemporal-forecasting
- Host: GitHub
- URL: https://github.com/rose-stl-lab/dyffusion
- Owner: Rose-STL-Lab
- License: apache-2.0
- Created: 2023-06-02T21:13:10.000Z (over 2 years ago)
- Default Branch: main
- Last Pushed: 2024-10-18T17:30:13.000Z (12 months ago)
- Last Synced: 2025-03-28T14:07:10.788Z (7 months ago)
- Topics: deep-learning, diffusion, diffusion-models, ensemble-forecasts, machine-learning, neurips, neurips-2023, probabilistic-forecasting, pytorch, pytorch-lightning, spatiotemporal-forecasting
- Language: Python
- Homepage: https://salvarc.github.io/blog/2023/dyffusion
- Size: 521 KB
- Stars: 193
- Watchers: 2
- Forks: 22
- Open Issues: 2
Metadata Files:
- Readme: README.md
- Contributing: CONTRIBUTING.md
- License: LICENSE
# DYffusion: A Dynamics-informed Diffusion Model for Spatiotemporal Forecasting (NeurIPS 2023)
✨Official implementation of our DYffusion paper✨
*DYffusion forecasts a sequence of* $h$ *snapshots* $\mathbf{x}_1, \mathbf{x}_2, \dots, \mathbf{x}_h$
*given the initial conditions* $\mathbf{x}_0$ *similarly to how standard diffusion models are used to sample from a distribution.*
If you use this code, please consider citing our work. Copy the bibtex from the bottom of this Readme or cite as:

> [DYffusion: A Dynamics-informed Diffusion Model for Spatiotemporal Forecasting](https://arxiv.org/abs/2306.01984),\
> Salva Rühling Cachay, Bo Zhao, Hailey Joren, and Rose Yu,\
> *Advances in Neural Information Processing Systems (NeurIPS)*, 2023

## | Environment Setup
We recommend installing ``dyffusion`` in a virtual environment from PyPI or Conda.
For more details about installing [PyTorch](https://pytorch.org/get-started/locally/), please refer to their official documentation.
For some compute setups you may want to install PyTorch first for proper GPU support.

``python3 -m pip install ".[train]"``
## | Downloading Data
**Navier-Stokes and spring mesh:**
Follow the instructions given by the [original dataset paper](https://github.com/karlotness/nn-benchmark).
Or, simply run our scripts to download the data. For Navier-Stokes: ``bash scripts/download_navier_stokes.sh``.
For spring mesh: ``bash scripts/download_spring_mesh.sh``.

By default, the data are downloaded to ``$HOME/data/physical-nn-benchmark``;
you can override this by setting ``DATA_DIR`` in the [scripts/download_physical_systems_data.sh](scripts/download_physical_systems_data.sh) script.

**Sea surface temperatures:**
Pre-processed SST data can be downloaded from Zenodo: https://zenodo.org/record/7259555

**IMPORTANT:** By default, our code expects the data to be in the ``$HOME/data/physical-nn-benchmark`` and ``$HOME/data/oisstv2`` directories.
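
Before launching a run, it can help to confirm that these default directories exist. The following is a minimal sketch that assumes the default layout under ``$HOME/data`` described above; ``missing_data_dirs`` is a hypothetical helper and not part of the repository.

```python
from pathlib import Path

# Directory names taken from the defaults stated in this README.
DEFAULT_SUBDIRS = ("physical-nn-benchmark", "oisstv2")


def missing_data_dirs(root=None):
    """Return the expected dataset directories that do not exist under root.

    root defaults to $HOME/data (an assumption based on the default paths above).
    """
    base = Path(root) if root is not None else Path.home() / "data"
    return [base / name for name in DEFAULT_SUBDIRS if not (base / name).is_dir()]


# Report anything that still needs to be downloaded.
for path in missing_data_dirs():
    print(f"missing: {path}")
```

An empty result only means the directories exist; it does not verify their contents.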
**Using a different data directory:**
If you want to use a different directory, either set the
`datamodule.data_dir` command line argument (e.g. `python run.py datamodule.data_dir=/path/to/data`), or
permanently edit the ``data_dir`` variable in the [src/configs/datamodule/_base_data_config.yaml](src/configs/datamodule/_base_data_config.yaml) file.

## | Running experiments
Please see the [src/README.md](src/README.md) file for detailed instructions on how to run experiments, navigate the code, and run with different configurations.
### Train DYffusion
**First stage:** Train the *interpolator* network. E.g. with
```
python run.py experiment=spring_mesh_interpolation
```

**Second stage:** Train the *forecaster* network. E.g. with
```
python run.py experiment=spring_mesh_dyffusion diffusion.interpolator_run_id=
```
Note that we currently rely on Weights & Biases for logging and checkpointing,
so record the wandb run ID of the interpolator training run; you need it to train the forecaster network as above.
You can find the run's ID, for example, in the URL of the run's page on wandb.ai.
E.g. in ``https://wandb.ai///runs/i73blbh0`` the run ID is ``i73blbh0``.

#### Training DYffusion on your own data
We advise creating your own datamodule by following the example ones in [datamodules/](src/datamodules) and creating a
corresponding yaml config file in [configs/datamodule/](src/configs/datamodule).
*First stage:* It is worth spending some time/compute on optimizing the interpolator network (in terms of CRPS) before training the forecaster network.
To do so, it is important to sweep over the dropout rate(s).
But any other hyperparameter like the learning rate that leads to better CRPS will likely transfer to the overall performance of DYffusion as well.
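
For reference, CRPS (continuous ranked probability score) can be estimated from an ensemble of sampled predictions. The sketch below uses the standard ensemble estimator, E|X − y| − ½·E|X − X′|, for a single scalar target in pure Python. The function name ``crps_ensemble`` is illustrative, and this is not necessarily how the repository computes the metric.

```python
def crps_ensemble(samples, observation):
    """Estimate CRPS for one scalar observation from ensemble samples.

    Uses E|X - y| - 0.5 * E|X - X'|, averaging the second term over all
    ordered pairs of members (including i == j). Lower is better.
    """
    n = len(samples)
    if n == 0:
        raise ValueError("need at least one ensemble member")
    # Mean absolute error of the members against the observation.
    skill = sum(abs(x - observation) for x in samples) / n
    # Mean absolute difference between pairs of members (ensemble spread).
    spread = sum(abs(a - b) for a in samples for b in samples) / (n * n)
    return skill - 0.5 * spread


print(crps_ensemble([1.0, 3.0], 2.0))  # 0.5
print(crps_ensemble([2.0], 2.5))       # 0.5 (one member: plain absolute error)
```

With a single member the spread term vanishes and the score reduces to the absolute error, which is why a deterministic model cannot benefit from the spread term.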
*Second stage:*
The full set of configuration options for training DYffusion/the forecaster net is defined and briefly explained in [src/configs/diffusion/dyffusion.yaml](src/configs/diffusion/dyffusion.yaml).
It can be useful to try out different values for ``forward_conditioning``,
check whether setting ``additional_interpolation_steps>0`` (i.e. ``k>0``) helps to improve the performance,
and enable ``refine_intermediate_predictions=True`` (you may do so after finishing training).

### Wandb integration
We use [Weights & Biases](https://wandb.ai/) for logging and checkpointing.
Please set your wandb username/entity in the [src/configs/logger/wandb.yaml](src/configs/logger/wandb.yaml) file.
Alternatively, you can set the `logger.wandb.entity` command line argument (e.g. `python run.py logger.wandb.entity=my_username`).

### Reproducing results
You can use any of the yaml configs in the [src/configs/experiment](src/configs/experiment) directory to (re-)run experiments.
Each experiment file name defines a particular dataset and method/model combination following the pattern ``_.yaml``.
For example, you can train the ``Dropout`` baseline on the spring mesh dataset with:

``python run.py experiment=spring_mesh_time_conditioned``

Please note that to train DYffusion you need to start with the interpolation stage first, before running the ``_dyffusion`` experiment,
as described above.

#### Testing a trained model
To test a trained model, take note of its wandb run ID and then run:

``python run.py mode=test logger.wandb.id=``

Alternatively, reload the model from a local checkpoint file with:

``python run.py mode=test logger.wandb.id= ckpt_path=``

It is important to set the `mode=test` flag, so that the model is tested appropriately (e.g. predict 50 samples per initial condition).
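
The test invocations above can also be assembled programmatically, for example when evaluating several runs. The following is a minimal sketch: it only uses the ``run.py`` arguments quoted in this README, while the helper names and the entity/project in the sample URL are hypothetical placeholders.

```python
from urllib.parse import urlparse


def run_id_from_url(url):
    """Return the path segment that follows 'runs' in a wandb run URL."""
    parts = [p for p in urlparse(url).path.split("/") if p]
    if "runs" not in parts or parts.index("runs") + 1 >= len(parts):
        raise ValueError(f"no run ID found in {url!r}")
    return parts[parts.index("runs") + 1]


def build_test_command(run_id, ckpt_path=None, project=None):
    """Build the argument list for testing a trained model."""
    cmd = ["python", "run.py", "mode=test", f"logger.wandb.id={run_id}"]
    if ckpt_path is not None:
        # Reload from a local checkpoint file instead of the wandb one.
        cmd.append(f"ckpt_path={ckpt_path}")
    if project is not None:
        # Only needed when working with multiple wandb projects.
        cmd.append(f"logger.wandb.project={project}")
    return cmd


# Hypothetical entity and project names; the run ID is the README's example.
run_id = run_id_from_url("https://wandb.ai/my_username/my_project/runs/i73blbh0")
print(build_test_command(run_id))
# ['python', 'run.py', 'mode=test', 'logger.wandb.id=i73blbh0']
```

The resulting list can be passed to ``subprocess.run(cmd, check=True)`` from the directory that contains ``run.py``.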
If you're using multiple wandb projects, you may also need to set the `logger.wandb.project` flag.

### Debugging
By default, we use all training trajectories for training our models.
To debug the physical systems experiments, feel free to use fewer training trajectories by setting:
``python run.py datamodule.num_trajectories=1``. To accelerate training for the SST experiments, you may run with fewer
regional boxes (the default is 11 boxes) with ``python run.py 'datamodule.boxes=[88]'``.
Generally, you can also try mixed precision training with ``python run.py trainer.precision=16``.

## | Citation
@inproceedings{cachay2023dyffusion,
title={{DYffusion:} A Dynamics-informed Diffusion Model for Spatiotemporal Forecasting},
author={R{\"u}hling Cachay, Salva and Zhao, Bo and Joren, Hailey and Yu, Rose},
booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
url={https://openreview.net/forum?id=WRGldGm5Hz},
year={2023}
}