https://github.com/beekill95/predrnn_tf
PredRNN implementation using Tensorflow.
- Host: GitHub
- URL: https://github.com/beekill95/predrnn_tf
- Owner: beekill95
- License: mit
- Created: 2023-04-06T17:20:28.000Z (over 2 years ago)
- Default Branch: main
- Last Pushed: 2023-05-25T18:20:40.000Z (over 2 years ago)
- Last Synced: 2025-01-13T19:47:55.301Z (9 months ago)
- Topics: deep-learning, lstm, predictive-learning, predrnn, rnn, rnn-tensorflow, spatial-temporal, tensorflow
- Language: Python
- Homepage:
- Size: 129 KB
- Stars: 1
- Watchers: 2
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
README
# PredRNN implementation using Tensorflow.
This is an implementation of PredRNN using Tensorflow.
The implementation is based on:

* [PredRNN: recurrent neural networks for predictive learning using spatiotemporal LSTMs](https://dl.acm.org/doi/abs/10.5555/3294771.3294855)
* [PredRNN: A Recurrent Neural Network for Spatiotemporal Predictive Learning](https://arxiv.org/abs/2103.09504)
* [Official implementation in Pytorch](https://github.com/thuml/predrnn-pytorch)
* [kami93/PredRNN](https://github.com/kami93/PredRNN)

Specifically, this repo implements the second paper:
the (stacked) spatiotemporal cell and reversed scheduled sampling.
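
Reversed scheduled sampling gradually shifts the network's inputs between ground-truth frames and its own predictions during training. A minimal sketch of the per-timestep mixing step, assuming frames of shape `(batch, height, width, channels)` (illustrative only, not this repo's API):

```python
import tensorflow as tf


def sample_next_input(true_frame, predicted_frame, p_true):
    """With probability `p_true`, feed the ground-truth frame;
    otherwise feed the model's own prediction back in."""
    batch_size = tf.shape(true_frame)[0]
    # One Bernoulli decision per sample, broadcast over height/width/channels.
    use_true = tf.random.uniform([batch_size, 1, 1, 1]) < p_true
    return tf.where(use_true, true_frame, predicted_frame)
```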
There is an implementation of the decoupling loss;
however, due to a bug with Keras's `add_loss()` function,
enabling the loss will raise an error at runtime.
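
For reference, the decoupling loss in the PredRNN-v2 paper penalizes the cosine similarity between the increments of the temporal and spatial memories. A minimal sketch of how such a term could be attached with `add_loss()` (illustrative only; not this repo's actual code, and the memory increments `delta_c`/`delta_m` are assumed inputs):

```python
import tensorflow as tf


class DecoupleLoss(tf.keras.layers.Layer):
    """Adds a loss term that pushes the two memory increments apart."""

    def call(self, inputs):
        delta_c, delta_m = inputs  # increments of the temporal/spatial memories
        # Flatten everything but the batch dimension.
        c = tf.reshape(delta_c, [tf.shape(delta_c)[0], -1])
        m = tf.reshape(delta_m, [tf.shape(delta_m)[0], -1])
        # |cosine similarity| between the increments, averaged over the batch.
        cos = tf.reduce_sum(
            tf.math.l2_normalize(c, axis=-1) * tf.math.l2_normalize(m, axis=-1),
            axis=-1)
        self.add_loss(tf.reduce_mean(tf.abs(cos)))
        return inputs
```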
## Accuracy

Performance (MSE) on the Moving MNIST dataset
(results produced by running `scripts/moving_mnist_all_batch.sh 5`):

| Model                                  | Run #1  | Run #2  | Run #3  | Run #4  | Run #5  | Mean     | Std             |
|----------------------------------------|---------|---------|---------|---------|---------|----------|-----------------|
| ConvLSTM                               | 0.03618 | 0.00951 | 0.04391 | 0.00956 | 0.03485 | 0.026802 | 0.01613769407   |
| PredRNN without Scheduled Sampling     | 0.00705 | 0.00728 | 0.00703 | 0.00723 | 0.00733 | 0.007184 | 0.0001363084737 |
| PredRNN with Linear Scheduled Sampling | 0.00798 | 0.00741 | 0.01012 | 0.01468 | 0.00822 | 0.009682 | 0.00297355343   |

Performance (MSE) on the Moving MNIST dataset with different scheduled sampling strategies:

| Model                                   | Run #1  | Run #2  | Run #3  | Run #4  | Run #5  | Mean     | Std             |
|-----------------------------------------|---------|---------|---------|---------|---------|----------|-----------------|
| PredRNN with Linear Scheduled Sampling  | 0.01254 | 0.01095 | 0.01231 | 0.01032 | 0.01071 | 0.011366 | 0.0009958564154 |
| PredRNN with Expo Scheduled Sampling    | 0.00816 | 0.00868 | 0.01382 | 0.01027 | 0.01094 | 0.010374 | 0.002234810954  |
| PredRNN with Sigmoid Scheduled Sampling | 0.00938 | 0.00999 | 0.00797 | 0.01007 | 0.00775 | 0.009032 | 0.001105404903  |
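
The three strategies differ only in how quickly the probability of sampling the true frame ramps up during training. A sketch of plausible schedule shapes (the exact constants used by this repo's training scripts are assumptions):

```python
import numpy as np


def true_frame_probability(step, total_steps, schedule="linear"):
    """Probability of feeding the ground-truth frame at a given training step.

    Each schedule increases toward 1.0 over training; the constants below
    are illustrative, not the values used by this repo.
    """
    t = step / total_steps
    if schedule == "linear":
        return min(0.5 + 0.5 * t, 1.0)
    if schedule == "expo":
        # Exponential decay of the sampling gap toward 1.0.
        return 1.0 - 0.5 * np.exp(-5.0 * t)
    if schedule == "sigmoid":
        return 1.0 / (1.0 + np.exp(-10.0 * (t - 0.5)))
    raise ValueError(f"Unknown schedule: {schedule}")
```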
## Installation

The repository can be installed as a module using either `pip` or `poetry`.
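
For example, assuming a direct install from GitHub (the repo does not document an exact command here), `pip install git+https://github.com/beekill95/predrnn_tf` should work.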
## Development
The repo uses `poetry` to manage dependencies.
To install all the development dependencies, use `poetry install`.

## Examples
In the `examples` folder,
there are some notebooks that create a PredRNN model
and use it on the Moving MNIST dataset.

These examples can be opened as Jupyter notebooks using Jupytext.
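
For instance, with Jupytext's standard CLI, `jupytext --to notebook examples/moving_mnist_predrnn.py` should produce a runnable `.ipynb` copy of that example.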
Jupytext and the other development dependencies will be installed along with
the required dependencies by `poetry install`.

Since these examples are resource-intensive and time-consuming,
it is recommended to run them on GPU clusters.
In the `scripts` folder,
there are some scripts to run these examples on IU's Carbonate GPU clusters
(you will need authorization to use these clusters, of course!).
Specifically,

* `sbatch scripts/moving_mnist.sbatch` will schedule a batch job to run
the example `examples/moving_mnist_predrnn.py`.
The output notebook will be `examples/moving_mnist_predrnn_.ipynb`
and the model will be saved as `saved_models/moving_mnist_predrnn_`.
* Similarly for `sbatch scripts/moving_mnist_ss.sbatch`.
* `scripts/moving_mnist_all_batch.sh <n>` will submit a series
of jobs (3 * `<n>`) to the GPU clusters.
There will be jobs for each combination of PredRNN/ConvLSTM
and with/without scheduled sampling.
The goal of this script is to show whether PredRNN is better than ConvLSTM
at predicting future frames for the Moving MNIST dataset.