https://github.com/wizyoung/minimal-rnn-tensorflow
Implementation of Minimal RNN
- Host: GitHub
- URL: https://github.com/wizyoung/minimal-rnn-tensorflow
- Owner: wizyoung
- Created: 2018-05-29T02:46:43.000Z (about 8 years ago)
- Default Branch: master
- Last Pushed: 2018-09-16T14:01:50.000Z (almost 8 years ago)
- Last Synced: 2025-12-27T13:31:27.122Z (6 months ago)
- Topics: minimalrnn, rnn, tensorflow
- Language: Python
- Size: 285 KB
- Stars: 2
- Watchers: 1
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
# Minimal-RNN-TensorFlow
This is a TensorFlow implementation of the paper [MinimalRNN: Toward More Interpretable and Trainable Recurrent Neural Networks](https://arxiv.org/abs/1711.06788) by Minmin Chen (NIPS 2017).
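
For orientation, here is a minimal NumPy sketch of a single MinimalRNN update as I understand the paper's formulation: the input is first mapped to a latent vector `z_t = tanh(W_x x_t + b_z)`, an update gate `u_t = sigmoid(U_h h_{t-1} + U_z z_t + b_u)` is computed, and the new state is `h_t = u_t * h_{t-1} + (1 - u_t) * z_t`. This is not the code in this repo; the function names, parameter names, and the random initialization are assumptions made for illustration only.

```python
import numpy as np


def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))


def init_params(feat_dim, num_units, seed=0):
    # Hypothetical parameter layout; small random weights for illustration only.
    rng = np.random.default_rng(seed)
    return {
        "W_x": rng.normal(0.0, 0.1, (feat_dim, num_units)),
        "b_z": np.zeros(num_units),
        "U_h": rng.normal(0.0, 0.1, (num_units, num_units)),
        "U_z": rng.normal(0.0, 0.1, (num_units, num_units)),
        "b_u": np.zeros(num_units),
    }


def minimal_rnn_step(x_t, h_prev, params):
    """One update. x_t: [batch, feat_dim], h_prev: [batch, num_units]."""
    # Map the input into the latent space.
    z_t = np.tanh(x_t @ params["W_x"] + params["b_z"])
    # Update gate computed from the previous state and the latent input.
    u_t = sigmoid(h_prev @ params["U_h"] + z_t @ params["U_z"] + params["b_u"])
    # New state: gated mix of the previous state and the latent input.
    return u_t * h_prev + (1.0 - u_t) * z_t


params = init_params(feat_dim=8, num_units=4)
# Time-major toy input: [seq_length, batch_size, feat_dim]
x = np.random.default_rng(1).normal(size=(5, 3, 8))
h = np.zeros((3, 4))
for x_t in x:
    h = minimal_rnn_step(x_t, h, params)
print(h.shape)  # (3, 4)
```

Because the state is a gated mix of the previous state and a `tanh` output, every entry of `h` stays inside (-1, 1) when starting from a zero state.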

### Usage
Usage is simple, because the Minimal RNN cell exposes the same API as TensorFlow's other RNN cells (such as LSTM and GRU): just `from rnn_cell import MinimalRNNCell` and use the standard TensorFlow RNN API.
Example (a two-layer stacked RNN):
```python
import tensorflow as tf
from rnn_cell import MinimalRNNCell
# Input shape: [batch_size, seq_length, feat_dim]
inputs = tf.placeholder(tf.float32, [160, 100, 1024], name='inputs')

def get_rnn_cell():
    # One Minimal RNN layer with 128 units and an orthogonal kernel initializer
    return MinimalRNNCell(num_units=128, kernel_initializer=tf.orthogonal_initializer())

# Stack two Minimal RNN layers
multi_rnn_cell_video = tf.contrib.rnn.MultiRNNCell([get_rnn_cell() for _ in range(2)], state_is_tuple=True)
initial_state = multi_rnn_cell_video.zero_state(batch_size=160, dtype=tf.float32)
rnn_outputs, state = tf.nn.dynamic_rnn(
    cell=multi_rnn_cell_video,
    inputs=inputs,
    initial_state=initial_state,
    dtype=tf.float32
)
print(rnn_outputs)
print(state)
```
Output:
```
Tensor("rnn/transpose_1:0", shape=(160, 100, 128), dtype=float32)
(<tf.Tensor ...>, <tf.Tensor ...>)
```
So the usage is identical to that of other RNN cells such as GRU.
### NOTE
TensorFlow's built-in RNN cells (including LSTM and GRU) are defined in [tensorflow/python/ops/rnn_cell_impl.py](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/rnn_cell_impl.py). The Minimal RNN cell in this repo inherits from the `RNNCell` class in that file so that it has a consistent API. The RNN cell API changed significantly after TensorFlow 1.4, so I implemented two versions of the Minimal RNN cell for compatibility: one for TensorFlow <= 1.4 and one for TensorFlow > 1.4. The right version is selected automatically, so you don't need to worry about it.
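
To illustrate the idea behind the automatic switch, here is a rough sketch of how a version check of this kind could be written. It is an assumption for illustration, not the code used in this repo: `parse_version` and `needs_legacy_cell` are hypothetical helper names. The version is compared as a tuple of integers, because comparing raw strings would wrongly rank `"1.10"` below `"1.4"`.

```python
def parse_version(version):
    """Return (major, minor) from a string such as '1.4.0' or '1.10.0-rc1'."""
    major, minor = version.split(".")[:2]
    return int(major), int(minor)


def needs_legacy_cell(tf_version):
    # Hypothetical rule mirroring the note above:
    # TensorFlow <= 1.4 uses the older cell, anything newer uses the newer cell.
    return parse_version(tf_version) <= (1, 4)


print(needs_legacy_cell("1.4.0"))   # True
print(needs_legacy_cell("1.10.0"))  # False
```

In a real module the argument would typically be `tf.__version__`, and the result would decide which of the two cell classes is exported as `MinimalRNNCell`.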