https://github.com/i404788/s5-pytorch
Pytorch implementation of Simplified Structured State-Spaces for Sequence Modeling (S5)
- Host: GitHub
- URL: https://github.com/i404788/s5-pytorch
- Owner: i404788
- License: mpl-2.0
- Created: 2023-03-20T23:57:07.000Z (about 3 years ago)
- Default Branch: master
- Last Pushed: 2024-04-26T09:36:59.000Z (about 2 years ago)
- Last Synced: 2025-09-18T06:51:40.530Z (8 months ago)
- Topics: pytorch, s5, sequence-modeling, state-space
- Language: Python
- Homepage:
- Size: 57.6 KB
- Stars: 78
- Watchers: 1
- Forks: 4
- Open Issues: 2
Metadata Files:
- Readme: README.md
- License: LICENSE
# S5: Simplified State Space Layers for Sequence Modeling
This is a ported version derived from and .
It includes a bunch of functions ported from jax/lax/flax/whatever since they didn't exist in PyTorch yet.
~~Jax is required because it relies on the pytree structure but it's not used for any computation.~~
Since version 0.2.0, jax is not required; it uses PyTorch's native `torch.utils._pytree` instead (this is a private module, so it may be incompatible with future PyTorch versions).
PyTorch 2 or later is required because it makes heavy use of `torch.vmap` and `torch.utils._pytree` to substitute their jax counterparts.
Python 3.10 or later is required due to usage of the `match` keyword.
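The requirements above can be checked before installing. The following is a minimal sketch, not part of this package: `check_s5_requirements` is a hypothetical helper name, and it only tests for the interpreter version and the two PyTorch features mentioned above.

```python
import sys


def check_s5_requirements():
    """Return a dict mapping each requirement listed above to True/False.

    Hypothetical helper for illustration; s5-pytorch does not ship it.
    """
    report = {"python>=3.10": sys.version_info >= (3, 10)}
    try:
        import torch
    except ImportError:
        # Without PyTorch the remaining checks cannot be run.
        report["torch installed"] = False
        return report
    report["torch installed"] = True
    report["torch.vmap"] = hasattr(torch, "vmap")
    try:
        # Private module: it may move or change between PyTorch releases.
        import torch.utils._pytree  # noqa: F401
        report["torch.utils._pytree"] = True
    except ImportError:
        report["torch.utils._pytree"] = False
    return report


for name, ok in check_s5_requirements().items():
    print(f"{name}: {'ok' if ok else 'MISSING'}")
```

If any entry prints `MISSING`, upgrading Python or PyTorch first is likely to save a confusing import error later.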

---

Update:
In my experiments, it matches the findings of the [Hyena Hierarchy](https://arxiv.org/abs/2302.10866) (& H3) paper: state spaces alone lack the recall capabilities required for LLMs, but they seem to work well for regular sequence feature extraction and offer linear complexity.
You can use a variable step size, as described in the paper, by passing a 1D tensor for `step_scale`; however, this takes **a lot of memory** because many intermediate values need to be held (which I believe is also true for the official S5 repo, but it is not mentioned in the paper unless I missed it).
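To make the memory remark more concrete, here is a rough pure-Python sketch of one plausible reason. It is not this library's code. It assumes a diagonal state space with the zero-order-hold discretization from the S5 paper, and it runs the recurrence sequentially rather than with a parallel scan. The names `discretize_zoh`, `run_diagonal_ssm`, and `base_delta` are made up for illustration. With a scalar step scale the discretized parameters are computed once per state; with one step scale per timestep they are computed once per state per timestep.

```python
import cmath


def discretize_zoh(lam, b, delta):
    """Zero-order-hold discretization of one diagonal state x' = lam*x + b*u.

    Assumes lam != 0.
    """
    lam_bar = cmath.exp(lam * delta)
    b_bar = (lam_bar - 1) / lam * b
    return lam_bar, b_bar


def run_diagonal_ssm(lams, bs, inputs, step_scale, base_delta=0.1):
    """Run x_k = lam_bar * x_{k-1} + b_bar * u_k for every state.

    step_scale is either one number (shared by all timesteps) or a
    list with one number per timestep. Returns the state history and
    the number of discretized (lam_bar, b_bar) pairs that were held.
    """
    per_step = isinstance(step_scale, (list, tuple))
    scales = list(step_scale) if per_step else [step_scale]
    if per_step and len(scales) != len(inputs):
        raise ValueError("need one step scale per timestep")
    # One row of discretized pairs per distinct step size:
    # 1 row for a scalar step scale, len(inputs) rows for a 1D one.
    table = [
        [discretize_zoh(lam, b, base_delta * s) for lam, b in zip(lams, bs)]
        for s in scales
    ]
    states = [0j] * len(lams)
    history = []
    for k, u in enumerate(inputs):
        row = table[k] if per_step else table[0]
        states = [lb * x + bb * u for (lb, bb), x in zip(row, states)]
        history.append(states)
    return history, len(table) * len(lams)


lams = [-0.5 + 1.0j, -0.1 - 2.0j]  # two diagonal states
bs = [1.0 + 0j, 0.5 + 0j]
inputs = [1.0, 0.0, -1.0, 2.0]  # four timesteps

_, held_scalar = run_diagonal_ssm(lams, bs, inputs, 1.0)
_, held_varying = run_diagonal_ssm(lams, bs, inputs, [1.0, 0.5, 2.0, 1.0])
print(held_scalar, held_varying)  # 2 8
```

In this toy setting the held parameters grow from the state size to sequence length times state size. The real implementation may hold additional intermediates, so treat this as a lower-bound intuition rather than a measurement.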
## Install
```sh
pip install s5-pytorch
```
## Example
```py3
import torch
from s5 import S5, S5Block

# Raw S5 operator
x = torch.rand([2, 256, 32])  # [batch, sequence length, features]
model = S5(32, 32)
model(x)  # [2, 256, 32]

# S5-former block (S5+FFN-GLU w/ layernorm, dropout & residual)
model = S5Block(32, 32, False)
model(x)  # [2, 256, 32]
```