https://github.com/titu1994/tf-sha-rnn
Tensorflow port implementation of Single Headed Attention RNN
- Host: GitHub
- URL: https://github.com/titu1994/tf-sha-rnn
- Owner: titu1994
- License: mit
- Created: 2020-02-01T03:36:59.000Z (over 5 years ago)
- Default Branch: master
- Last Pushed: 2020-02-01T03:53:01.000Z (over 5 years ago)
- Last Synced: 2025-03-25T05:34:07.126Z (7 months ago)
- Language: Python
- Size: 15.6 KB
- Stars: 16
- Watchers: 3
- Forks: 1
- Open Issues: 1
Metadata Files:
- Readme: README.md
- License: LICENSE
README
# Single Headed Attention RNN for Tensorflow 2.0
For full details see the paper [Single Headed Attention RNN: Stop Thinking With Your Head](https://arxiv.org/abs/1911.11423).

Code is ported from the author's original implementation here - https://github.com/Smerity/sha-rnn
# Usage
The `SHARNN` model class is, for the most part, a direct port of the original PyTorch codebase. In Tensorflow, it can be used either directly as a Keras Model or added as a sublayer of another Model (a sketch of the latter follows below). The model can be traced by `tf.function`, so performance degradation should be minimal even when custom training loops are used.
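For the sublayer use mentioned above, here is a minimal sketch. The wrapper model is hypothetical and not part of this repository; its constructor arguments simply mirror the example in the next section.

```python
import tensorflow as tf

from sharnn import SHARNN


class LanguageModel(tf.keras.Model):
    """Hypothetical wrapper that embeds SHARNN as a sublayer."""

    def __init__(self):
        super().__init__()
        self.sha_rnn = SHARNN(num_token=1000, embed_dim=100, num_hid=200,
                              num_layers=2, return_hidden=True, return_mem=True)

    def call(self, x, training=False):
        # With return_hidden/return_mem enabled, SHARNN returns the output
        # along with the recurrent hidden states and attention memories.
        h, hidden, mems = self.sha_rnn(x, training=training)
        return h
```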
## As a Keras Model
```python
from sharnn import SHARNN

model = SHARNN(num_token=1000, embed_dim=100, num_hid=200, num_layers=2,
               return_hidden=True, return_mem=True)

model.compile(optimizer='adam', loss='mse')

# Test predict
model.predict(x)
model.summary()
```
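The snippet above uses an input batch `x` that is not defined there. A minimal sketch of building one follows; the integer-token-id content and the (sequence_length, batch_size) layout are assumptions carried over from the original PyTorch implementation and may differ in this port.

```python
import numpy as np

# Hypothetical dummy batch of token ids in [0, num_token).
seq_len, batch_size = 64, 8
x = np.random.randint(0, 1000, size=(seq_len, batch_size), dtype=np.int32)

outputs = model.predict(x)
```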
## Inside a custom training loop

```python
import tensorflow as tf


@tf.function
def model_forward_with_grads(model, x):
    with tf.GradientTape() as tape:
        h, new_hidden, new_mems = model(x, training=True)
        h, new_hidden, new_mems = model(x, hidden=new_hidden, mems=new_mems, training=True)
        loss = tf.reduce_sum(h)  # Just for demonstration purposes

    grad = tape.gradient(loss, model.trainable_variables)
    return loss, grad
```
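One possible way to consume the returned gradients, sketched here with a standard Keras optimizer (the optimizer choice and learning rate are illustrative, not part of the original code):

```python
import tensorflow as tf

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

# One training step: compute loss/gradients with the traced function,
# then apply the gradients to the model's trainable variables.
loss, grads = model_forward_with_grads(model, x)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
```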
# Caveats

There is currently an issue with setting a maximum number of positions kept in `mems` (see the TODO in the code). As a result, there is currently no limit on the amount of memory that `mems` can take; a possible manual workaround is sketched below.
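Until that is addressed, one option is to truncate `mems` by hand between steps. This is only a sketch: it assumes `mems` is a list of tensors whose first axis indexes positions (mirroring the PyTorch implementation), and the actual layout in this port may differ.

```python
MAX_MEM_POSITIONS = 512  # illustrative limit, not from the original code

def truncate_mems(mems, max_len=MAX_MEM_POSITIONS):
    # Keep only the most recent `max_len` positions of each memory tensor.
    if mems is None:
        return None
    return [m[-max_len:] for m in mems]
```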
# Requirements

- Tensorflow 2.0+