https://github.com/titu1994/tf_star_rnn
Tensorflow 2.0 implementation of STAR RNN
- Host: GitHub
- URL: https://github.com/titu1994/tf_star_rnn
- Owner: titu1994
- License: mit
- Created: 2020-06-07T00:44:24.000Z (over 5 years ago)
- Default Branch: master
- Last Pushed: 2020-06-07T00:55:05.000Z (over 5 years ago)
- Last Synced: 2025-03-25T05:34:05.058Z (7 months ago)
- Language: Python
- Size: 109 KB
- Stars: 10
- Watchers: 3
- Forks: 1
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
README
# STAR Recurrent Neural Network for Tensorflow 2.0
TensorFlow 2.0 implementation of STAckable Recurrent (STAR) neural networks from the paper [Gating Revisited: Deep Multi-layer RNNs That Can Be Trained](https://arxiv.org/abs/1911.11033).
Code ported from the original authors' implementation: https://github.com/0zgur0/STAR_Network
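For intuition, here is a minimal NumPy sketch of a single-layer STAR recurrence as we read the paper: a candidate `z_t = tanh(W_z x_t + b_z)` computed from the input alone, a single gate `k_t = sigmoid(W_x x_t + W_h h_{t-1} + b_k)`, and the update `h_t = tanh((1 - k_t) * h_{t-1} + k_t * z_t)`. This is not the code in `star_rnn.py`: the weight names, the small random initialization, and the lack of batching and dropout are our own simplifications, so check the equations against the paper and `STARCell` before relying on them.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def star_step(x_t, h_prev, params):
    """One STAR update (sketch). x_t: (input_dim,), h_prev: (units,)."""
    W_z, b_z, W_xk, W_hk, b_k = params  # hypothetical parameter names
    z_t = np.tanh(x_t @ W_z + b_z)                   # candidate, input only
    k_t = sigmoid(x_t @ W_xk + h_prev @ W_hk + b_k)  # the single gate
    # Gate blends old state and candidate; tanh keeps h_t in (-1, 1)
    return np.tanh((1.0 - k_t) * h_prev + k_t * z_t)


def star_sequence(xs, units, seed=0):
    """Run the sketch over a (time, input_dim) sequence; returns (time, units)."""
    rng = np.random.default_rng(seed)
    input_dim = xs.shape[1]
    params = (
        rng.normal(0.0, 0.1, (input_dim, units)),  # W_z
        np.zeros(units),                           # b_z
        rng.normal(0.0, 0.1, (input_dim, units)),  # W_xk
        rng.normal(0.0, 0.1, (units, units)),      # W_hk
        np.zeros(units),                           # b_k (assumed zero here)
    )
    h = np.zeros(units)
    outputs = []
    for x_t in xs:
        h = star_step(x_t, h, params)
        outputs.append(h)
    return np.stack(outputs)


hs = star_sequence(np.ones((5, 3)), units=4)
print(hs.shape)
```

In this sketch, stacking layers simply means feeding one layer's `hs` in as the next layer's `xs`.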
# Stackable RNN
## Usage
Add the two files `star_rnn.py` and `initializers.py` to your project, then import and use them as below:
```python
import tensorflow as tf
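from star_rnn import STARCell

# Example values (illustrative only; set them for your own task)
MAX_TIMESTEPS = 100  # sequence length
channel_dim = 2      # features per timestep
RNN_UNITS = 128
num_classes = 1

# Model config
NUM_LAYERS = 4
DROPOUT = 0.1
time_dim = MAX_TIMESTEPS  # should be a positive integer

# Model definition
ip = tf.keras.layers.Input(shape=(time_dim, channel_dim))
x = ip
states = None

# NUM_LAYERS - 1 sequence-returning STAR layers; each layer's final state
# becomes the initial state of the next layer
for i in range(NUM_LAYERS - 1):
    x, states = tf.keras.layers.RNN(STARCell(RNN_UNITS, t_max=time_dim, dropout=DROPOUT),
                                    return_sequences=True, return_state=True)(x, initial_state=states)

# Last STAR layer returns only the final output
x = tf.keras.layers.RNN(STARCell(RNN_UNITS, t_max=time_dim, dropout=DROPOUT))(x, initial_state=states)
x = tf.keras.layers.Dense(num_classes, activation='linear', bias_initializer='he_uniform')(x)
model = tf.keras.Model(inputs=ip, outputs=x)
```

## Evaluation on Addition task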
Run the `add_train.py` script.
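The adding problem is a standard long-range memory benchmark: each input step carries a random value in [0, 1] and a 0/1 marker, exactly two steps are marked, and the target is the sum of the two marked values. The generator below is a hypothetical sketch of that standard formulation; we have not checked that `add_train.py` builds its data the same way (value range, marker placement, or sequence length may differ). Its `(batch, time, 2)` inputs would correspond to `channel_dim = 2`, and its single regression target to `num_classes = 1` with the linear `Dense` head.

```python
import numpy as np


def adding_problem_batch(batch_size, time_steps, seed=None):
    """Hypothetical adding-problem generator: returns x (B, T, 2), y (B, 1)."""
    rng = np.random.default_rng(seed)
    values = rng.uniform(0.0, 1.0, size=(batch_size, time_steps))
    markers = np.zeros((batch_size, time_steps))
    for b in range(batch_size):
        # Mark exactly two distinct timesteps per sequence
        i, j = rng.choice(time_steps, size=2, replace=False)
        markers[b, i] = 1.0
        markers[b, j] = 1.0
    x = np.stack([values, markers], axis=-1)              # channel 0: value, channel 1: marker
    y = np.sum(values * markers, axis=1, keepdims=True)   # sum of the two marked values
    return x.astype(np.float32), y.astype(np.float32)


x, y = adding_problem_batch(batch_size=8, time_steps=50, seed=0)
print(x.shape, y.shape)
```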
# Citation
If you use this code, please cite the authors' paper:
```
@article{turkoglu2019gating,
title={Gating Revisited: Deep Multi-layer RNNs That Can Be Trained},
author={Turkoglu, Mehmet Ozgur and D'Aronco, Stefano and Wegner, Jan Dirk and Schindler, Konrad},
journal={arXiv preprint arXiv:1911.11033},
year={2019}
}
```