https://github.com/titu1994/tf_star_rnn

Tensorflow 2.0 implementation of STAR RNN

# STAR Recurrent Neural Network for TensorFlow 2.0

TensorFlow 2.0 implementation of STAckable Recurrent (STAR) neural networks, from the paper [Gating Revisited: Deep Multi-layer RNNs That Can Be Trained](https://arxiv.org/abs/1911.11033).

Code ported from the original authors' implementation: https://github.com/0zgur0/STAR_Network
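To show what a single STAR layer computes, here is a minimal NumPy sketch of the STAR update as we read it from the paper: a candidate `z_t = tanh(W_z x_t + b_z)`, a single gate `k_t = sigmoid(W_x x_t + W_h h_{t-1} + b_k)`, and the new state `h_t = tanh((1 - k_t) * h_{t-1} + k_t * z_t)`, where `x_t` is the output of the layer below at time `t`. This is not the repository's `STARCell`: the helper names (`star_step`, `init_params`, `star_stack`), the Gaussian weight initialization and the zero initial state per layer are our own assumptions, and the sketch leaves out dropout and whatever the `t_max` argument controls in the Keras cell.

```python
import numpy as np


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


def star_step(x_t, h_prev, params):
    """One STAR update for one layer (sketch of the paper's equations).

    x_t:    input at time t (output of the layer below), shape (d_in,)
    h_prev: this layer's hidden state at time t-1, shape (d_h,)
    """
    W_z, b_z, W_x, W_h, b_k = params
    z = np.tanh(W_z @ x_t + b_z)                 # candidate, from the input only
    k = sigmoid(W_x @ x_t + W_h @ h_prev + b_k)  # the single STAR gate
    # Gate-weighted mix of old state and candidate, squashed by tanh
    return np.tanh((1.0 - k) * h_prev + k * z)


def init_params(d_in, d_h, rng):
    # Hypothetical Gaussian init; NOT the repository's initializers.py
    s_in, s_h = 1.0 / np.sqrt(d_in), 1.0 / np.sqrt(d_h)
    return (rng.normal(0.0, s_in, (d_h, d_in)), np.zeros(d_h),   # W_z, b_z
            rng.normal(0.0, s_in, (d_h, d_in)),                  # W_x
            rng.normal(0.0, s_h, (d_h, d_h)), np.zeros(d_h))     # W_h, b_k


def star_stack(xs, layer_params):
    """Run a stack of STAR layers over xs of shape (T, d_in).

    Each layer starts from a zero state (an assumption of this sketch)
    and its full output sequence becomes the next layer's input.
    Returns the top layer's outputs, shape (T, d_h).
    """
    seq = xs
    for params in layer_params:
        h = np.zeros(params[1].shape[0])
        outputs = []
        for x_t in seq:
            h = star_step(x_t, h, params)
            outputs.append(h)
        seq = np.stack(outputs)
    return seq


rng = np.random.default_rng(0)
T, d_in, d_h, n_layers = 10, 2, 8, 4
layers = [init_params(d_in if i == 0 else d_h, d_h, rng) for i in range(n_layers)]
out = star_stack(rng.uniform(size=(T, d_in)), layers)
print(out.shape)  # (10, 8)
```

Because each layer has only one gate and the state passes through a bounded `tanh`, every hidden value stays in (-1, 1), which is part of why the paper argues STAR layers can be stacked deeply.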

# Stackable RNN

## Usage

Add the two files `star_rnn.py` and `initializers.py` to your project, then import and use them as below:

```python
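import tensorflow as tf
from star_rnn import STARCell

# Model config (example values; set these for your own data)
NUM_LAYERS = 4
DROPOUT = 0.1
RNN_UNITS = 128       # hidden units per STAR layer
MAX_TIMESTEPS = 100   # length of the input sequences
channel_dim = 2       # features per timestep
num_classes = 1       # output dimension
time_dim = MAX_TIMESTEPS  # should be a positive integer

# Model definition
ip = tf.keras.layers.Input(shape=(time_dim, channel_dim))

x = ip
states = None
# Stack NUM_LAYERS - 1 STAR layers that return full sequences; each layer's
# final state is passed as the initial state of the next layer
for i in range(NUM_LAYERS - 1):
    x, states = tf.keras.layers.RNN(STARCell(RNN_UNITS, t_max=time_dim, dropout=DROPOUT),
                                    return_sequences=True, return_state=True)(x, initial_state=states)

# Last STAR layer returns only the output at the final timestep
x = tf.keras.layers.RNN(STARCell(RNN_UNITS, t_max=time_dim, dropout=DROPOUT))(x, initial_state=states)
x = tf.keras.layers.Dense(num_classes, activation='linear', bias_initializer='he_uniform')(x)

model = tf.keras.Model(inputs=ip, outputs=x)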
```

## Evaluation on Addition task

Run the `add_train.py` script.
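The adding problem is a standard long-range memory benchmark: each input sequence has two channels, uniformly random values and a marker that is 1 at exactly two timesteps, and the target is the sum of the two marked values. The sketch below generates such a batch with NumPy. `adding_problem_batch` is a hypothetical helper, not code from `add_train.py`; we have not checked how that script samples marker positions (some variants place one marker in each half of the sequence).

```python
import numpy as np


def adding_problem_batch(batch_size, t_max, rng):
    """Generate one batch of the adding problem (hypothetical helper).

    Channel 0: uniform random values in [0, 1).
    Channel 1: marker, 1.0 at exactly two random positions, else 0.0.
    Target: sum of the two marked values.
    """
    values = rng.uniform(0.0, 1.0, size=(batch_size, t_max))
    markers = np.zeros((batch_size, t_max))
    for b in range(batch_size):
        # Two distinct marker positions per sequence
        i, j = rng.choice(t_max, size=2, replace=False)
        markers[b, i] = markers[b, j] = 1.0
    x = np.stack([values, markers], axis=-1)           # (batch, t_max, 2)
    y = (values * markers).sum(axis=1, keepdims=True)  # (batch, 1)
    return x.astype(np.float32), y.astype(np.float32)


rng = np.random.default_rng(42)
x, y = adding_problem_batch(batch_size=32, t_max=100, rng=rng)
print(x.shape, y.shape)  # (32, 100, 2) (32, 1)
```

With `channel_dim` set to 2, `num_classes` set to 1 and `time_dim` equal to `t_max`, batches of this shape match the input and output of the model in the Usage section, and a regression loss such as `'mse'` is the natural choice for the target.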

# Citation

Please cite the authors' paper:

```
@article{turkoglu2019gating,
title={Gating Revisited: Deep Multi-layer RNNs That Can Be Trained},
author={Turkoglu, Mehmet Ozgur and D'Aronco, Stefano and Wegner, Jan Dirk and Schindler, Konrad},
journal={arXiv preprint arXiv:1911.11033},
year={2019}
}
```