# Convolutional State Space Models for Long-Range Spatiotemporal Modeling
This repository provides the official JAX implementation for the paper:
**Convolutional State Space Models for Long-Range Spatiotemporal Modeling** [[arXiv]](https://arxiv.org/abs/2310.19694)
[Jimmy T.H. Smith](https://jimmysmith1919.github.io/),
[Shalini De Mello](https://research.nvidia.com/person/shalini-de-mello),
[Jan Kautz](https://jankautz.com),
[Scott Linderman](https://web.stanford.edu/~swl1/),
[Wonmin Byeon](https://wonmin-byeon.github.io/),
NeurIPS 2023.
For business inquiries, please visit the NVIDIA website and submit the form: [NVIDIA Research Licensing](https://www.nvidia.com/en-us/research/inquiries/).
---
We introduce **ConvSSM**, an efficient method for long-range spatiotemporal sequence modeling. It is parallelizable, overcomes major limitations of traditional ConvRNNs (e.g., vanishing/exploding gradients), and provides unbounded context and fast autoregressive generation compared to Transformers. It performs similarly to or better than Transformers/ConvLSTM on long-horizon video prediction tasks, trains up to 3× faster than ConvLSTM, and generates samples up to 400× faster than Transformers. We provide results for the long-horizon Moving MNIST generation task and long-range 3D environment benchmarks (DMLab, Minecraft, and Habitat).
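To make the recurrence concrete, below is a minimal, illustrative JAX sketch of a ConvSSM-style state update, in which the matrix multiplications of a standard state space model are replaced by convolutions. All names, shapes, and initializations here are hypothetical; the repository's ConvS5 layers use a more careful parameterization and a parallel scan for training.
```python
import jax
import jax.numpy as jnp

def conv2d(kernel, x):
    # "Same"-padded 2D convolution of a (C_in, H, W) feature map with a
    # (C_out, C_in, kH, kW) kernel, using JAX's default NCHW/OIHW layout.
    return jax.lax.conv_general_dilated(
        x[None], kernel, window_strides=(1, 1), padding="SAME"
    )[0]

def convssm_scan(A, B, C, inputs, x0):
    # Sequential form of the ConvSSM recurrence:
    #   x_k = A * x_{k-1} + B * u_k,   y_k = C * x_k,
    # where * denotes 2D convolution.
    def step(x_prev, u):
        x = conv2d(A, x_prev) + conv2d(B, u)
        return x, conv2d(C, x)
    _, ys = jax.lax.scan(step, x0, inputs)
    return ys

# Toy shapes: 10 frames, 1 input channel, 4 state channels, 16x16 pixels.
kA, kB, kC, ku = jax.random.split(jax.random.PRNGKey(0), 4)
A = 0.1 * jax.random.normal(kA, (4, 4, 3, 3))
B = 0.1 * jax.random.normal(kB, (4, 1, 3, 3))
C = 0.1 * jax.random.normal(kC, (1, 4, 3, 3))
u = jax.random.normal(ku, (10, 1, 16, 16))
print(convssm_scan(A, B, C, u, jnp.zeros((4, 16, 16))).shape)  # (10, 1, 16, 16)
```
Because the update is linear in the state, the same recurrence can also be evaluated with a parallel scan over time, which is what makes ConvSSMs parallelizable in a way ConvRNNs are not.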

The repository builds on the training pipeline from [TECO](https://github.com/wilson1yan/teco).
---
### Installation
You will need to install JAX following the instructions [here](https://jax.readthedocs.io/en/latest/installation.html).
We used JAX version 0.3.21.
```commandline
pip install --upgrade "jax[cuda]==0.3.21" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
```
Then install the rest of the dependencies with:
```commandline
sudo apt-get update && sudo apt-get install -y ffmpeg
pip install -r requirements.txt
pip install -e .
```
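After installing, a quick sanity check (not part of the repository) confirms the JAX version and that your GPUs are visible:
```python
import jax

print(jax.__version__)  # expect 0.3.21
print(jax.devices())    # should list GPU devices, not just CPU
```
---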
### Datasets
For `Moving-MNIST`:

1) Download the MNIST binary file:
```commandline
wget http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz -O data/moving-mnist-pytorch/train-images-idx3-ubyte.gz
```
2) Use the script in `data/moving-mnist-pytorch` to generate the Moving MNIST data; a sketch of the standard procedure is shown below.
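For reference, the standard "bouncing digits" construction looks roughly like this (illustrative only; the actual script in `data/moving-mnist-pytorch` controls sequence length, digit count, canvas size, and output format):
```python
import gzip
import numpy as np

def load_mnist_images(path):
    # Parse the idx3-ubyte format: a 16-byte header, then uint8 pixels.
    with gzip.open(path, "rb") as f:
        data = f.read()
    n = int.from_bytes(data[4:8], "big")
    return np.frombuffer(data[16:], dtype=np.uint8).reshape(n, 28, 28)

def moving_mnist_sequence(digits, num_frames=20, size=64, seed=0):
    # Paste 28x28 digits onto a blank canvas and bounce them off the walls.
    rng = np.random.default_rng(seed)
    canvas = np.zeros((num_frames, size, size), dtype=np.uint8)
    pos = rng.uniform(0, size - 28, size=(len(digits), 2))
    vel = rng.uniform(-3, 3, size=(len(digits), 2))
    for t in range(num_frames):
        for d, digit in enumerate(digits):
            pos[d] += vel[d]
            for ax in range(2):  # reflect at the borders
                if pos[d, ax] < 0 or pos[d, ax] > size - 28:
                    vel[d, ax] *= -1
                    pos[d, ax] = np.clip(pos[d, ax], 0, size - 28)
            r, c = pos[d].astype(int)
            region = canvas[t, r:r + 28, c:c + 28]
            np.copyto(region, np.maximum(region, digit))
    return canvas

images = load_mnist_images(
    "data/moving-mnist-pytorch/train-images-idx3-ubyte.gz")
print(moving_mnist_sequence(images[:2]).shape)  # (20, 64, 64)
```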
For the 3D environment tasks:
We used the scripts from the [TECO](https://github.com/wilson1yan/teco) repository to download the datasets: [`DMLab`](https://github.com/wilson1yan/teco/blob/master/scripts/download/dmlab.sh) and [`Habitat`](https://github.com/wilson1yan/teco/blob/master/scripts/download/habitat.sh). Check the TECO repository for details of the datasets. The data should be split into `train` and `test` folders.
---
### Pretrained VQ-GANs
Pretrained VQ-GAN checkpoints for each dataset can be found [here](https://drive.google.com/drive/folders/10hAqVjoxte9OxYc7WIih_5OtwbdOxKoi). Note these are also from [TECO](https://github.com/wilson1yan/teco).

---
### Pretrained ConvS5 checkpoints
Pretrained ConvS5 checkpoints for each dataset can be found [here](https://www.dropbox.com/scl/fo/h3omm0bc3dau9uh9cgrq0/AICA1umpuN1LRG_MRwUyPWU?rlkey=s9w4d3ncsfz39n2r390dpbsk2&st=v722uk6x&dl=0). Download each checkpoint into its checkpoint directory; a sketch for inspecting a checkpoint follows the table.
Default checkpoint directory: `logs//checkpoints/`

| dataset | checkpoint | config |
|:---:|:---:|:---:|
| Moving-MNIST 300 | [link](https://www.dropbox.com/scl/fo/wg6f4cazhlw5cs3fjfapf/AERdhoK8HARlwwlGvu8ZpLY?rlkey=spq6umv7m2scywntwgqswxvys&st=hpbemnp2&dl=0) | `Moving-MNIST/300_train_len/mnist_convS5_novq.yaml` |
| Moving-MNIST 600 | [link](https://www.dropbox.com/scl/fo/1vog37ntlr67084o6qbpm/AD8B3ZZIek9pxDvb80rhd4k?rlkey=krp36u6zbu8nac4foml9f3bq1&st=hmuymguf&dl=0) | `Moving-MNIST/600_train_len/mnist_convS5_novq.yaml` |
| DMLab | [link](https://www.dropbox.com/scl/fo/dcy9nhw0umbowang36po1/AKPtYSxP2ynJnUDvoZsdxqc?rlkey=cw7bas02w2mw7ldyephu3w9r7&st=20wwyamh&dl=0) | `3D_ENV_BENCHMARK/dmlab/dmlab_convs5.yaml` |
| Habitat | [link](https://www.dropbox.com/scl/fo/6k6tchauqaguilkr8rb7c/ACyEPN_X00f1xWM_RQFyDF8?rlkey=gx5o11o9n5npfj09gxac8hq2p&st=9k23lyfk&dl=0) | `3D_ENV_BENCHMARK/habitat/habitat_teco_convS5.yaml` |
| Minecraft | [link](https://www.dropbox.com/scl/fo/c4g2ol85hbt58kveoklek/AJvYilfaNdaap1V89u5Z5Oo?rlkey=1pv457c6bal7t2pisqx20s51i&st=8ft9r7d0&dl=0) | `3D_ENV_BENCHMARK/minecraft/minecraft_teco_convS5.yaml` |
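To poke at a downloaded checkpoint, something like the following should work, assuming the files are standard Flax checkpoints as in the TECO training pipeline (the path below is an assumption based on the default layout; adjust it to your checkpoint directory):
```python
from flax.training import checkpoints

# target=None returns the raw checkpoint as nested Python dicts;
# "logs/dmlab_convs5/checkpoints" is a hypothetical example path.
raw = checkpoints.restore_checkpoint(
    "logs/dmlab_convs5/checkpoints", target=None)
print(list(raw.keys()) if raw is not None else "no checkpoint found")
```
---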
### Training
Before training, you will need to update the paths in the corresponding config files to point to your dataset and VQ-GAN directories. To train, run:

`python scripts/train.py -d <data_dir> -o <output_dir> -c <config>`

Example for training ConvS5 on DMLab:
```commandline
python scripts/train.py -d datasets/dmlab -o dmlab_convs5 -c configs/3D_ENV_BENCHMARK/dmlab/dmlab_convs5.yaml
```
Note: we only used data-parallel training for our experiments. Model-parallel training will require implementing JAX [xmap](https://jax.readthedocs.io/en/latest/notebooks/xmap_tutorial.html) or [pjit/jit](https://jax.readthedocs.io/en/latest/jax-101/08-pjit.html). See [this](https://github.com/wilson1yan/teco/tree/master/teco/models/xmap) folder in the TECO repo for an example using xmap.
Our runs were performed in a multinode NVIDIA V100 32GB GPU environment.
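For orientation, data-parallel training in JAX typically follows the `jax.pmap` pattern sketched below (a toy stand-in model, not the repository's actual training step):
```python
import functools
import jax
import jax.numpy as jnp

def loss_fn(params, batch):
    # Toy linear model standing in for the real network.
    pred = batch["x"] @ params["w"]
    return jnp.mean((pred - batch["y"]) ** 2)

@functools.partial(jax.pmap, axis_name="data")
def train_step(params, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    # Average gradients across devices before the SGD update.
    grads = jax.lax.pmean(grads, axis_name="data")
    params = jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, grads)
    return params, loss

n_dev = jax.local_device_count()
params = jax.device_put_replicated({"w": jnp.zeros((8, 1))}, jax.local_devices())
batch = {  # the leading axis is the device axis
    "x": jnp.ones((n_dev, 32, 8)),
    "y": jnp.ones((n_dev, 32, 1)),
}
params, loss = train_step(params, batch)
print(loss)  # one (identical) loss value per device
```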
---
### Evaluation
To evaluate run:
`python scripts/eval.py -d <data_dir> -o <output_dir> -c <config>`

Example for evaluating ConvS5 on DMLab:
```commandline
python scripts/eval.py -d datasets/dmlab -o dmlab_convs5 -c configs/3D_ENV_BENCHMARK/dmlab/dmlab_convs5_eval.yaml
```
This performs the sampling required to compute the different evaluation metrics; the generated videos are saved into `npz` files.
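Each `npz` file can be inspected with NumPy; the filename below is an assumption, so substitute whatever `eval.py` writes into your log directory:
```python
import numpy as np

# List the arrays stored in a sample file, with their shapes and dtypes.
with np.load("logs/dmlab_convs5/samples_36/videos.npz") as data:  # hypothetical name
    for name in data.files:
        print(name, data[name].shape, data[name].dtype)
```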
For FVD evaluations, run: `python scripts/compute_fvd.py <samples_dir>`
Example for ConvS5 on DMLab:
```commandline
python scripts/compute_fvd.py logs/dmlab_convs5/samples_36
```
For PSNR, SSIM, and LPIPS, run: `python scripts/compute_metrics.py <samples_dir>`
Example for ConvS5 on DMLab:
```commandline
python scripts/compute_metrics.py logs/dmlab_convs5/samples_action_144
```
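For reference, PSNR between two uint8 videos reduces to a few lines of NumPy; this is an illustrative sketch, not the repository's `compute_metrics.py`, which also reports SSIM and LPIPS via external libraries:
```python
import numpy as np

def psnr(a, b, max_val=255.0):
    # Peak signal-to-noise ratio in dB between arrays of equal shape.
    mse = np.mean((a.astype(np.float64) - b.astype(np.float64)) ** 2)
    return np.inf if mse == 0 else 10.0 * np.log10(max_val**2 / mse)

real = np.random.randint(0, 256, (16, 64, 64, 3), dtype=np.uint8)
fake = np.random.randint(0, 256, (16, 64, 64, 3), dtype=np.uint8)
print(psnr(real, fake))  # ~7.8 dB for independent random noise
```
---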
### Citation
Please use the following when citing our work:

```bibtex
@inproceedings{
smith2023convolutional,
title={Convolutional State Space Models for Long-Range Spatiotemporal Modeling},
author={Jimmy T.H. Smith and Shalini De Mello and Jan Kautz and Scott Linderman and Wonmin Byeon},
booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
year={2023},
url={https://openreview.net/forum?id=1ZvEtnrHS1}
}
```
---
### License
Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
See LICENSE file for details.

Please reach out if you have any questions.
-- The ConvS5 authors.