Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/pedro-r-marques/keras-list-mapper
A Keras layer that performs a map operation over a ragged tensor
- Host: GitHub
- URL: https://github.com/pedro-r-marques/keras-list-mapper
- Owner: pedro-r-marques
- License: apache-2.0
- Created: 2020-11-20T20:20:17.000Z (about 4 years ago)
- Default Branch: main
- Last Pushed: 2021-12-02T10:00:37.000Z (about 3 years ago)
- Last Synced: 2024-11-11T14:44:08.953Z (2 months ago)
- Topics: keras, keras-layer, keras-tensorflow
- Language: Python
- Homepage:
- Size: 43.9 KB
- Stars: 1
- Watchers: 2
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
# Keras RaggedTensor Mapper
This package implements a Keras layer that applies a map() operation over
one or more RaggedTensors. This is useful when the application processes
sequences of variable lengths. For instance, in an NLP context, it is common
to process both small and large documents. For this type of application,
RaggedTensors allow the input data to be encoded as a variable-length list
(of pages or N paragraphs). Each of these list elements can then be processed
by a neural network that uses fixed-dimension tensors. Often, the operation
applied to each element needs to propagate state forward to the next element.
The ListMapper layer supports this by allowing the user to define a state
vector shape.

## Example
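```python
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

from keras_list_mapper.keras_list_mapper import ListMapper


class RecurrentCell(layers.Layer):
    """Example recurrent cell: (state, features) -> (new_state, output)."""

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def call(self, inputs):
        state, features = inputs
        # Accumulate the features of this step into the running state.
        nstate = state + features
        # Emit one scalar per example for this step.
        output = tf.reduce_mean(nstate, axis=-1)
        return nstate, output


def make_model():
    # Ragged input: [batch, None (sequence), 4 (features)].
    inp = layers.Input(shape=(None, 4), ragged=True)
    map_fn = RecurrentCell()
    # The state vector has the same shape as one feature vector: (4,).
    m = ListMapper(map_fn, state_shape=(4,))
    mr = m(inp)
    # Sum the per-step outputs of each example.
    s = layers.Lambda(lambda x: tf.reduce_sum(x, axis=-1))(mr)
    model = keras.Model(inp, s)
    model.compile(optimizer="adam", loss="mse")
    return model


model = make_model()
# 8 rows of 4 features; float32 matches the Input layer's default dtype.
values = tf.reshape(tf.range(32, dtype=tf.float32), (8, 4))
# Group the 8 rows into 4 sequences of lengths 3, 2, 2 and 1.
x = tf.RaggedTensor.from_row_lengths(values, [3, 2, 2, 1])
model.predict(x)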
```

In this example, a `RecurrentCell` is applied over the ragged dimension of the tensor. The cell performs a computation and stores its state in a state vector.
The RaggedTensor `x` has shape [4, None, 4]: a batch of 4 examples with sequence lengths [3, 2, 2, 1], where each step has a feature dimension of 4.
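To make that shape concrete, here is a small plain-Python sketch (no TensorFlow) of how row lengths partition the 8 flat rows of `values` into the ragged batch, which is what `tf.RaggedTensor.from_row_lengths` does in the example. The helper `split_by_row_lengths` is hypothetical and exists only for illustration.

```python
def split_by_row_lengths(rows, row_lengths):
    """Group consecutive rows into sub-lists of the given lengths (hypothetical helper)."""
    if sum(row_lengths) != len(rows):
        raise ValueError("row_lengths must sum to the number of rows")
    batch, start = [], 0
    for n in row_lengths:
        batch.append(rows[start:start + n])
        start += n
    return batch


# Same data as the example: the values 0..31 arranged as 8 rows of 4 features.
values = [list(range(i * 4, i * 4 + 4)) for i in range(8)]
x = split_by_row_lengths(values, [3, 2, 2, 1])

print(len(x))                   # 4
print([len(seq) for seq in x])  # [3, 2, 2, 1]
print(x[3])                     # [[28, 29, 30, 31]]
```

The outer length is the batch size (4), the inner lengths are the ragged `None` dimension, and every row keeps its 4 features.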
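`ListMapper` calls the `RecurrentCell` once for each valid (example, step)
pair, that is, only for the steps that actually exist in each row, and gives
each example in the batch its own state vector, which is carried forward from
one step to the next.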
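The per-example state handling described above can be illustrated with a plain-Python sketch. This is not the `keras_list_mapper` implementation: the functions `recurrent_cell` and `list_map` are hypothetical, and the sketch assumes each example's state starts as a zero vector, which this README does not specify. It mirrors `RecurrentCell` (state plus features, then the mean) and visits only the valid steps of each sequence.

```python
# ASSUMPTION: each example's state is initialized to zeros; the README does
# not say how ListMapper initializes state. Not the library's implementation.

def recurrent_cell(state, features):
    """Hypothetical pure-Python counterpart of RecurrentCell.call."""
    nstate = [s + f for s, f in zip(state, features)]
    output = sum(nstate) / len(nstate)  # mean over the feature axis
    return nstate, output


def list_map(cell, ragged_batch, state_size):
    """Apply `cell` to every valid (example, step) pair, threading state per example."""
    results = []
    for sequence in ragged_batch:
        state = [0.0] * state_size  # assumed initial state, one per example
        outputs = []
        for features in sequence:
            state, out = cell(state, features)
            outputs.append(out)
        results.append(outputs)
    return results


# Rebuild the example's ragged batch: 8 rows of 4 features, row lengths [3, 2, 2, 1].
values = [list(range(i * 4, i * 4 + 4)) for i in range(8)]
ragged, start = [], 0
for n in [3, 2, 2, 1]:
    ragged.append(values[start:start + n])
    start += n

mapped = list_map(recurrent_cell, ragged, state_size=4)
print(mapped)                        # [[1.5, 7.0, 16.5], [13.5, 31.0], [21.5, 47.0], [29.5]]
print([sum(seq) for seq in mapped])  # [25.0, 44.5, 68.5, 29.5]
```

Summing each inner list, as the `Lambda` layer in the Keras example does, yields one value per example. Whether `model.predict(x)` returns exactly these numbers depends on how `ListMapper` initializes its state.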