https://github.com/tcapelle/moving_mnist
Exploring Moving Mnist dataset with forecasting algorithms
- Host: GitHub
- URL: https://github.com/tcapelle/moving_mnist
- Owner: tcapelle
- License: apache-2.0
- Created: 2020-07-10T08:07:13.000Z (over 4 years ago)
- Default Branch: master
- Last Pushed: 2023-04-12T05:47:06.000Z (over 1 year ago)
- Last Synced: 2024-06-11T19:15:30.271Z (5 months ago)
- Language: Jupyter Notebook
- Homepage: https://tcapelle.github.io/moving_mnist/
- Size: 15 MB
- Stars: 31
- Watchers: 1
- Forks: 5
- Open Issues: 3
Metadata Files:
- Readme: README.md
- Contributing: CONTRIBUTING.md
- License: LICENSE
# Moving MNIST forecasting
> A little experiment using Convolutional RNNs to forecast moving MNIST digits.

```python
from fastai.vision.all import *
from moving_mnist.models.conv_rnn import *
from moving_mnist.data import *
```

```python
if torch.cuda.is_available():
    torch.cuda.set_device(0)
    print(torch.cuda.get_device_name())
```

    Quadro RTX 8000

## Install
The only dependency is fastai (version 2). See https://github.com/fastai/fastai2 for installation instructions.
## Example:
We will predict:
- `n_in`: 5 images
- `n_out`: 5 images
- `n_obj`: 3 objects

```python
ds = MovingMNIST(DATA_PATH, n_in=5, n_out=5, n_obj=[1,2,3])
```

```python
train_tl = TfmdLists(range(500), ImageTupleTransform(ds))
valid_tl = TfmdLists(range(100), ImageTupleTransform(ds))
```

```python
dls = DataLoaders.from_dsets(train_tl, valid_tl, bs=8,
                             after_batch=[Normalize.from_stats(*mnist_stats)]).cuda()
```

Left: Input, Right: Target

```python
dls.show_batch()
```

![png](docs/images/output_10_0.png)
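
To make the input/target layout above concrete, here is a minimal sketch of how a clip of `n_in + n_out` frames could be split into an input sequence and a target sequence. `split_clip` is a hypothetical helper written for illustration, not part of `moving_mnist`, and the frames are plain strings standing in for images. How `MovingMNIST` builds its sequences internally is not shown here.

```python
from typing import List, Sequence, Tuple, TypeVar

Frame = TypeVar("Frame")


def split_clip(clip: Sequence[Frame], n_in: int, n_out: int) -> Tuple[List[Frame], List[Frame]]:
    """Return the first n_in frames as input and the next n_out frames as target."""
    if len(clip) < n_in + n_out:
        raise ValueError("clip is shorter than n_in + n_out frames")
    inputs = list(clip[:n_in])
    targets = list(clip[n_in:n_in + n_out])
    return inputs, targets


# Placeholder frames: strings instead of 64x64 images (assumption for illustration).
clip = [f"frame_{i}" for i in range(10)]
inputs, targets = split_clip(clip, n_in=5, n_out=5)
print(inputs)
print(targets)
```

With `n_in=5` and `n_out=5`, the model sees frames 0 to 4 and is asked to forecast frames 5 to 9.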
`StackUnstack` takes care of stacking the list of images into a fat tensor and unstacking them at the end. We will also need to modify our loss function to take a list of tensors as input and target.
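
The stacking idea can be sketched in plain PyTorch. This is an illustrative approximation, not the repository's code: `StackUnstackSketch` and `StackLossSketch` are hypothetical names, stacking along `dim=1` is an assumption, and `nn.Identity` stands in for a real forecasting model.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class StackUnstackSketch(nn.Module):
    """Hypothetical wrapper: list of frames -> one stacked tensor -> list of frames."""

    def __init__(self, module, dim=1):
        super().__init__()
        self.module = module
        self.dim = dim

    def forward(self, frames):
        # Stack n_frames tensors of shape [bs, C, H, W] into [bs, n_frames, C, H, W].
        x = torch.stack(list(frames), dim=self.dim)
        y = self.module(x)
        # Split the output back into one tensor per time step.
        return list(y.unbind(dim=self.dim))


class StackLossSketch(nn.Module):
    """Hypothetical loss wrapper: stacks both lists before applying a tensor loss."""

    def __init__(self, loss_func, dim=1):
        super().__init__()
        self.loss_func = loss_func
        self.dim = dim

    def forward(self, preds, targets):
        stacked_preds = torch.stack(list(preds), dim=self.dim)
        stacked_targets = torch.stack(list(targets), dim=self.dim)
        return self.loss_func(stacked_preds, stacked_targets)


# Five random frames, batch size 8, one channel, 64x64 pixels.
frames = [torch.randn(8, 1, 64, 64) for _ in range(5)]
model = StackUnstackSketch(nn.Identity())  # identity model: output equals input
out = model(frames)
loss = StackLossSketch(F.mse_loss)(out, frames)
print(len(out), out[0].shape, loss.item())
```

The wrapped module only ever sees a single tensor, so any model that maps `[bs, n_frames, C, H, W]` to the same layout can be dropped in without knowing about lists.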
```python
model = StackUnstack(SimpleModel())
```

Since an `ImageSeq` is a `tuple` of images, we need to stack them to compute the loss.

```python
loss_func = StackLoss(MSELossFlat())
```

```python
learn = Learner(dls, model, loss_func=loss_func, cbs=[])
```

```python
learn.lr_find()
```

    SuggestedLRs(lr_min=0.005754399299621582, lr_steep=3.0199516913853586e-05)

![png](docs/images/output_16_2.png)

```python
learn.fit_one_cycle(4, 1e-4)
```

| epoch | train_loss | valid_loss | time |
|-------|------------|------------|-------|
| 0 | 0.915238 | 0.619522 | 00:12 |
| 1 | 0.669368 | 0.608123 | 00:12 |
| 2 | 0.570026 | 0.559723 | 00:12 |
| 3 | 0.528593 | 0.532774 | 00:12 |

```python
p,t = learn.get_preds()
```

As you can see, the result is a list of 5 tensors with 100 samples each.

```python
len(p), p[0].shape
```

    (5, torch.Size([100, 1, 64, 64]))

```python
def show_res(t, idx):
    im_seq = ImageSeq.create([t[i][idx] for i in range(5)])
    im_seq.show(figsize=(8,4));
```

```python
k = random.randint(0,99)  # randint includes both ends; valid sample indices are 0..99
show_res(t,k)
show_res(p,k)
```

![png](docs/images/output_22_0.png)

![png](docs/images/output_22_1.png)
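
Because predictions and targets are lists with one tensor per time step, the error can also be inspected per forecast step. The following is a minimal sketch under that assumption. `per_step_mse` is a hypothetical helper, and synthetic tensors stand in for the real `p` and `t`, so the numbers say nothing about the trained model.

```python
import torch


def per_step_mse(preds, targets):
    """Mean squared error for each time step of two equal-length lists of tensors."""
    return [torch.mean((p - t) ** 2).item() for p, t in zip(preds, targets)]


# Synthetic stand-ins with the same layout as above: 5 steps of [100, 1, 64, 64].
targets = [torch.zeros(100, 1, 64, 64) for _ in range(5)]
# Each "prediction" is off by a constant that grows with the forecast step.
preds = [t + 0.1 * (i + 1) for i, t in enumerate(targets)]

errors = per_step_mse(preds, targets)
for step, err in enumerate(errors):
    print(f"step {step}: mse={err:.4f}")
```

With the real outputs, `per_step_mse(p, t)` would show whether the error grows for frames further into the future.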
## Training Example:
- ConvGRU with attention and blur upsampling: [01_train_example.ipynb](01_train_example.ipynb)
- ConvGRU trained with Cross Entropy instead of MSE: [02_train_cross_entropy.ipynb](02_train_cross_entropy.ipynb)
- Seq2seq model trained with MSE: [03_trainseq2seq.ipynb](03_trainseq2seq.ipynb)
- PhyDNet ported to fastai: [04_train_phydnet.ipynb](04_train_phydnet.ipynb)