Hierarchical State Space Models
https://github.com/raunaqbhirangi/hiss
- Host: GitHub
- URL: https://github.com/raunaqbhirangi/hiss
- Owner: raunaqbhirangi
- License: MIT
- Created: 2024-02-14T18:28:59.000Z (12 months ago)
- Default Branch: main
- Last Pushed: 2024-04-12T13:50:34.000Z (10 months ago)
- Last Synced: 2024-08-01T04:02:09.674Z (6 months ago)
- Language: Python
- Size: 57.6 KB
- Stars: 34
- Watchers: 3
- Forks: 6
- Open Issues: 1
Metadata Files:
- Readme: README.md
- License: LICENSE
Awesome Lists containing this project
- Awesome-state-space-models - GitHub
README
# Hierarchical State Space Models (HiSS)
![poster-compressed](https://github.com/raunaqbhirangi/hiss/assets/73357354/33fe0d1d-a1f2-480b-9d5b-ac8318fbbae4)

> __Hierarchical State Space Models for Continuous Sequence-to-Sequence Modeling__ \
> Raunaq Bhirangi, Chenyu Wang, Venkatesh Pattabiraman, Carmel Majidi, Abhinav Gupta, Tess Hellebrekers and Lerrel Pinto \
> Paper: https://arxiv.org/abs/2402.10211 \
> Website: https://hiss-csp.github.io/

## About
HiSS is a simple technique that stacks deep state space models such as [S4](https://github.com/state-spaces/s4) and [Mamba](https://github.com/state-spaces/mamba) to reason over continuous sequences of sensory data across multiple temporal hierarchies. We also release CSP-Bench: a benchmark for sequence-to-sequence prediction from sensory data.

## Installation
1. Clone the repository.
2. Create a conda environment from the provided `env.yml` file: ```conda env create -f env.yml```
3. Install Mamba based on the official [instructions](https://github.com/state-spaces/mamba/tree/main?tab=readme-ov-file#installation).
Note: If you run into CUDA issues while installing Mamba, run ```export CUDA_HOME=$CONDA_PREFIX```, and try again. If you still have problems, install both `causal_conv1d` and `mamba-ssm` from source.
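Put together, a typical install might look like the sketch below. This is only a convenience summary of the steps above: the conda environment name is assumed to be `hiss` (use whatever name `env.yml` actually specifies), and the `CUDA_HOME` export is only needed if the Mamba build cannot find CUDA.
```
# Consolidated installation sketch (environment name "hiss" is an assumption;
# check env.yml for the actual name).
git clone https://github.com/raunaqbhirangi/hiss.git
cd hiss

# Create and activate the conda environment from the provided spec.
conda env create -f env.yml
conda activate hiss

# Only needed if the Mamba build cannot locate CUDA.
export CUDA_HOME=$CONDA_PREFIX

# Install Mamba and its dependency per the official instructions.
pip install causal-conv1d
pip install mamba-ssm
```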
## Data processing
1. Refer to [data_processing/README](./data_processing/README.md) to download and extract the required dataset.
2. Set the `DATA_DIR` variable in the [`hiss/utils/__init__.py`](https://github.com/raunaqbhirangi/hiss/blob/main/hiss/utils/__init__.py) file. This is the path to the parent directory containing folders corresponding to every dataset.
3. Process the datasets into a format compatible with training:
__Marker Writing__: `python data_processing/process_reskin_data.py -dd marker_writing__dataset`
__Intrinsic Slip__: `python data_processing/process_reskin_data.py -dd intrinsic_slip__dataset`
__Joystick Control__: `python data_processing/process_xela_data.py -dd joystick_control__dataset`
__RoNIN__: `python data_processing/process_ronin_data.py`
__VECtor__: `python data_processing/process_vector_data.py`
__TotalCapture__: `python data_processing/process_total_capture_data.py`
4. Run `create_dataset.py` for the respective dataset to preprocess the data and resample it at the desired frequencies (see the end-to-end sketch after this list):
__Marker Writing__: `python create_dataset.py --config-name marker_writing_config`
__Intrinsic Slip__: `python create_dataset.py --config-name intrinsic_slip_config`
__Joystick Control__: `python create_dataset.py --config-name joystick_control_config`
__RoNIN__:
`python create_dataset.py --config-name ronin_train_config`
`python create_dataset.py --config-name ronin_test_config`
__VECtor__: `python create_dataset.py --config-name vector_config`
__TotalCapture__:
`python create_dataset.py --config-name total_capture_train_config`
`python create_dataset.py --config-name total_capture_test_config`
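For concreteness, a full pass over a single dataset (Marker Writing here, as an example) chains the two steps above; it assumes the raw data has already been downloaded per [data_processing/README](./data_processing/README.md) and that `DATA_DIR` is set as described above.
```
# End-to-end data processing sketch for the Marker Writing dataset.

# Step 3: convert the raw recordings into a training-compatible format.
python data_processing/process_reskin_data.py -dd marker_writing__dataset

# Step 4: preprocess and resample the data at the desired frequencies.
python create_dataset.py --config-name marker_writing_config
```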
## Usage
To train HiSS models for sequential prediction, use the `train.py` script. For each dataset, we provide a `<dataset>_hiss_config.yaml` file in the `conf/` directory, containing model parameters corresponding to the best-performing HiSS model for the respective dataset. To train the model, simply run
```
python train.py --config-name <dataset>_hiss_config
```
New datasets can be added by creating a corresponding `Task` object in line with the tasks defined in `vt_state/tasks`, and creating a config file in `conf/data_env/`.
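As a concrete example, and assuming the per-dataset config names follow the `<dataset>_hiss_config` pattern (e.g. `marker_writing_hiss_config`), training runs would look like:
```
# Assumed config names following the <dataset>_hiss_config pattern.
python train.py --config-name marker_writing_hiss_config
python train.py --config-name intrinsic_slip_hiss_config
```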