https://github.com/kyegomez/hsss
Implementation of a Hierarchical Mamba as described in the paper: "Hierarchical State Space Models for Continuous Sequence-to-Sequence Modeling"
- Host: GitHub
- URL: https://github.com/kyegomez/hsss
- Owner: kyegomez
- License: MIT
- Created: 2024-02-16T16:37:19.000Z (over 1 year ago)
- Default Branch: main
- Last Pushed: 2024-11-11T08:46:31.000Z (6 months ago)
- Last Synced: 2025-04-19T20:16:50.498Z (29 days ago)
- Topics: ai, artificial-intelligence, jesus, machine-learning, ml, multi-modal, multi-modality, open-source, pytorch, rnn, rnns, ssms, tensorflow, zeta
- Language: Python
- Homepage: https://discord.gg/jkwyyFdANm
- Size: 2.19 MB
- Stars: 13
- Watchers: 2
- Forks: 2
- Open Issues: 1
- Metadata Files:
- Readme: README.md
- Funding: .github/FUNDING.yml
- License: LICENSE
README
[Discord](https://discord.gg/qUtxnK2NMf)
# HSSS
Implementation of a Hierarchical Mamba as described in the paper "Hierarchical State Space Models for Continuous Sequence-to-Sequence Modeling", but using Mamba blocks in place of the traditional SSMs. The flow is: single input -> low-level Mambas -> concat -> high-level SSM -> multiple outputs. [Paper link](https://arxiv.org/pdf/2402.10211.pdf)
I believe in this architecture a lot, as it separates local and global learning.
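To picture that flow, here is a minimal sketch using plain `nn.Linear` stand-ins rather than the real `LowLevelMamba` and high-level SSM; the dimensions (`d_low`, `d_high`) are illustrative assumptions, and the point is only to show how the low-level outputs are concatenated along the feature dimension and fused before the high-level model runs.

```python
import torch
import torch.nn as nn

# Illustrative sketch only: Linear layers stand in for the low- and high-level SSMs.
d_low, d_high, seq_len, batch = 12, 128, 100, 1

low_level = [nn.Linear(d_low, d_low) for _ in range(3)]  # placeholders for 3 low-level Mambas
fuse = nn.Linear(3 * d_low, d_high)                      # fusion projection after concatenation
high_level = nn.Linear(d_high, d_high)                   # placeholder for the high-level SSM

x = torch.randn(batch, seq_len, d_low)                   # embedded input sequence
low_outs = [m(x) for m in low_level]                     # each low-level model sees the same input
fused = fuse(torch.cat(low_outs, dim=-1))                # concat along features, project to d_high
out = high_level(fused)                                  # high-level model over the fused sequence
print(out.shape)                                         # torch.Size([1, 100, 128])
```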
## Install

`pip install hsss`

## Usage
```python
import torch

from hsss.model import LowLevelMamba, HSSS

# Random input text tokens
text = torch.randint(0, 10, (1, 100)).long()

# Low-level Mamba 1
mamba = LowLevelMamba(
    dim=12,            # model (embedding) dimension
    depth=6,           # number of Mamba blocks
    dt_rank=4,         # rank of the timestep (delta) projection
    d_state=4,         # SSM state dimension
    expand_factor=4,   # expansion factor for the inner dimension
    d_conv=6,          # kernel size of the depthwise convolution
    dt_min=0.001,      # minimum delta value at initialization
    dt_max=0.1,        # maximum delta value at initialization
    dt_init="random",  # delta projection initialization scheme
    dt_scale=1.0,      # scaling factor for delta initialization
    bias=False,        # bias in the linear projections
    conv_bias=True,    # bias in the convolution
    pscan=True,        # use the parallel scan instead of the sequential one
)

# Low-level Mamba 2 (same configuration)
mamba2 = LowLevelMamba(
    dim=12,
    depth=6,
    dt_rank=4,
    d_state=4,
    expand_factor=4,
    d_conv=6,
    dt_min=0.001,
    dt_max=0.1,
    dt_init="random",
    dt_scale=1.0,
    bias=False,
    conv_bias=True,
    pscan=True,
)

# Low-level Mamba 3 (same configuration)
mamba3 = LowLevelMamba(
    dim=12,
    depth=6,
    dt_rank=4,
    d_state=4,
    expand_factor=4,
    d_conv=6,
    dt_min=0.001,
    dt_max=0.1,
    dt_init="random",
    dt_scale=1.0,
    bias=False,
    conv_bias=True,
    pscan=True,
)

# HSSS: fuses the low-level Mambas and runs the high-level SSM over their outputs
hsss = HSSS(
    layers=[mamba, mamba2, mamba3],
    num_tokens=10,     # vocabulary size
    seq_length=100,    # sequence length
    dim=128,           # dimension of the high-level model
    depth=3,           # number of Mamba blocks in the high-level model
    dt_rank=2,         # rank of the timestep (delta) projection
    d_state=2,         # SSM state dimension
    expand_factor=2,   # expansion factor for the inner dimension
    d_conv=3,          # kernel size of the depthwise convolution
    dt_min=0.001,      # minimum delta value at initialization
    dt_max=0.1,        # maximum delta value at initialization
    dt_init="random",  # delta projection initialization scheme
    dt_scale=1.0,      # scaling factor for delta initialization
    bias=False,        # bias in the linear projections
    conv_bias=True,    # bias in the convolution
    pscan=True,        # use the parallel scan instead of the sequential one
    proj_layer=True,   # use the fusion projection layer on the concatenated outputs
)

# Forward pass
out = hsss(text)
print(out)
```
## Citation
```bibtex
@misc{bhirangi2024hierarchical,
title={Hierarchical State Space Models for Continuous Sequence-to-Sequence Modeling},
author={Raunaq Bhirangi and Chenyu Wang and Venkatesh Pattabiraman and Carmel Majidi and Abhinav Gupta and Tess Hellebrekers and Lerrel Pinto},
year={2024},
eprint={2402.10211},
archivePrefix={arXiv},
primaryClass={cs.LG}
}
```

## License

MIT

## Todo
- [ ] Implement chunking of the tokens by splitting the input along the sequence dimension (see the sketch below)
- [ ] Make the fusion projection layer dynamic rather than fixed to a single choice of linear layer, FFN, cross-attention, or output head
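For the first Todo item, one possible (unofficial) approach is to split the embedded sequence into chunks along the sequence dimension and route each chunk to its own low-level Mamba, instead of giving every low-level model the full sequence. The snippet below is only a sketch with hypothetical shapes; it shows the splitting step, not the actual HSSS wiring.

```python
import torch

# Hypothetical sketch of the chunking Todo item: split the sequence
# across the low-level models instead of feeding each one the full input.
batch, seq_len, dim = 1, 100, 12
num_low_level = 4                                # e.g. one chunk per low-level Mamba

x = torch.randn(batch, seq_len, dim)             # embedded input sequence
chunks = torch.chunk(x, num_low_level, dim=1)    # split along the sequence dimension

# Each chunk would go through its own low-level Mamba; here we only show the shapes.
for i, chunk in enumerate(chunks):
    print(i, chunk.shape)                        # torch.Size([1, 25, 12]) per chunk
```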