# Hierarchical State Space Models (HiSS)
![poster-compressed](https://github.com/raunaqbhirangi/hiss/assets/73357354/33fe0d1d-a1f2-480b-9d5b-ac8318fbbae4)

> __Hierarchical State Space Models for Continuous Sequence-to-Sequence Modeling__ \
> Raunaq Bhirangi, Chenyu Wang, Venkatesh Pattabiraman, Carmel Majidi, Abhinav Gupta, Tess Hellebrekers and Lerrel Pinto\
> Paper: https://arxiv.org/abs/2402.10211 \
> Website: https://hiss-csp.github.io/

## About
HiSS is a simple technique that stacks deep state space models (SSMs) such as [S4](https://github.com/state-spaces/s4) and [Mamba](https://github.com/state-spaces/mamba) to reason over continuous sequences of sensory data across multiple temporal hierarchies. We also release CSP-Bench, a benchmark for sequence-to-sequence prediction from sensory data.
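The idea of stacking SSMs over temporal hierarchies can be illustrated with a toy two-level model. This is a minimal sketch, not the repository's implementation: a scalar linear recurrence stands in for S4/Mamba, and the fixed coefficients and the choice of each chunk's last output as its feature are simplifying assumptions made for illustration.

```python
from typing import List


def linear_ssm(xs: List[float], a: float, b: float, c: float = 1.0) -> List[float]:
    """Toy scalar linear SSM: h_t = a * h_{t-1} + b * x_t, y_t = c * h_t.

    Stand-in for a deep SSM layer such as S4 or Mamba (illustration only).
    """
    h = 0.0
    ys = []
    for x in xs:
        h = a * h + b * x
        ys.append(c * h)
    return ys


def hiss_forward(xs: List[float], chunk_size: int) -> List[float]:
    """Two-level hierarchy: a low-level SSM summarizes each chunk, a high-level SSM runs over the summaries."""
    if chunk_size <= 0 or len(xs) % chunk_size != 0:
        raise ValueError("sequence length must be a positive multiple of chunk_size")
    # 1. Split the high-frequency sensor stream into equal, non-overlapping chunks.
    chunks = [xs[i:i + chunk_size] for i in range(0, len(xs), chunk_size)]
    # 2. The low-level SSM runs over each chunk; its last output is used as the chunk feature
    #    (a simplifying assumption for this sketch).
    features = [linear_ssm(chunk, a=0.9, b=0.1)[-1] for chunk in chunks]
    # 3. The high-level SSM runs over the shorter sequence of chunk features,
    #    producing one output per chunk, i.e. at a coarser temporal resolution.
    return linear_ssm(features, a=0.5, b=1.0)


if __name__ == "__main__":
    print(hiss_forward([1.0] * 8, chunk_size=4))  # two outputs, one per chunk
```

Because the high-level model only sees one feature per chunk, it operates on a sequence `chunk_size` times shorter than the raw sensor stream.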

## Installation
1. Clone the repository: `git clone https://github.com/raunaqbhirangi/hiss.git`

2. Create a conda environment from the provided `env.yml` file: `conda env create -f env.yml`

3. Install Mamba based on the official [instructions](https://github.com/state-spaces/mamba/tree/main?tab=readme-ov-file#installation).

Note: If you run into CUDA issues while installing Mamba, run `export CUDA_HOME=$CONDA_PREFIX` and try again. If you still have problems, install both `causal_conv1d` and `mamba-ssm` from source.
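After installation, a quick sanity check can confirm that the Mamba packages are visible to the active interpreter and show where `CUDA_HOME` points. This is a hedged sketch, not part of the repository: it assumes the import names `mamba_ssm` and `causal_conv1d`, and `importlib.util.find_spec` only locates packages without importing them, so it will not catch CUDA errors that surface at import time.

```python
import importlib.util
import os


def check_environment(packages=("mamba_ssm", "causal_conv1d")):
    """Report which top-level packages are importable and the current CUDA_HOME value."""
    # find_spec returns None when a top-level package cannot be found on sys.path.
    report = {name: importlib.util.find_spec(name) is not None for name in packages}
    # CUDA_HOME is the variable the note above suggests setting to $CONDA_PREFIX.
    report["CUDA_HOME"] = os.environ.get("CUDA_HOME")
    return report


if __name__ == "__main__":
    for key, value in check_environment().items():
        print(f"{key}: {value}")
```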

## Data processing
1. Refer to [data_processing/README](./data_processing/README.md) to download and extract the required dataset.

2. Set the `DATA_DIR` variable in the [`hiss/utils/__init__.py`](https://github.com/raunaqbhirangi/hiss/blob/main/hiss/utils/__init__.py) file to the path of the parent directory that contains one folder per dataset.

3. Process the datasets into a format compatible with training:

__Marker Writing__: `python data_processing/process_reskin_data.py -dd marker_writing__dataset`

__Intrinsic Slip__: `python data_processing/process_reskin_data.py -dd intrinsic_slip__dataset`

__Joystick Control__: `python data_processing/process_xela_data.py -dd joystick_control__dataset`

__RoNIN__: `python data_processing/process_ronin_data.py`

__VECtor__: `python data_processing/process_vector_data.py`

__TotalCapture__: `python data_processing/process_total_capture_data.py`

4. Run `create_dataset.py` with the respective dataset's config to preprocess the data and resample it to the desired frequencies.

__Marker Writing__: `python create_dataset.py --config-name marker_writing_config`

__Intrinsic Slip__: `python create_dataset.py --config-name intrinsic_slip_config`

__Joystick Control__: `python create_dataset.py --config-name joystick_control_config`

__RoNIN__:

`python create_dataset.py --config-name ronin_train_config`

`python create_dataset.py --config-name ronin_test_config`

__VECtor__: `python create_dataset.py --config-name vector_config`

__TotalCapture__:

`python create_dataset.py --config-name total_capture_train_config`

`python create_dataset.py --config-name total_capture_test_config`
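The two steps above can be chained per dataset with a small driver script. This is a sketch under stated assumptions: the script name `prepare_data.py` and the short dataset keys are hypothetical labels, while the commands themselves are copied from this section. It should be run from the repository root after `DATA_DIR` has been set, and it only prints commands unless `--run` is passed.

```python
import subprocess
import sys

# Per-dataset processing commands, copied from the steps above.
# The short dataset keys ("marker_writing", "ronin", ...) are labels chosen for this sketch.
PROCESS_ARGS = {
    "marker_writing": ["data_processing/process_reskin_data.py", "-dd", "marker_writing__dataset"],
    "intrinsic_slip": ["data_processing/process_reskin_data.py", "-dd", "intrinsic_slip__dataset"],
    "joystick_control": ["data_processing/process_xela_data.py", "-dd", "joystick_control__dataset"],
    "ronin": ["data_processing/process_ronin_data.py"],
    "vector": ["data_processing/process_vector_data.py"],
    "total_capture": ["data_processing/process_total_capture_data.py"],
}

# Config names passed to create_dataset.py for each dataset.
CREATE_CONFIGS = {
    "marker_writing": ["marker_writing_config"],
    "intrinsic_slip": ["intrinsic_slip_config"],
    "joystick_control": ["joystick_control_config"],
    "ronin": ["ronin_train_config", "ronin_test_config"],
    "vector": ["vector_config"],
    "total_capture": ["total_capture_train_config", "total_capture_test_config"],
}


def pipeline(dataset):
    """Return the ordered commands that process and then resample one dataset."""
    # sys.executable runs the scripts with the current interpreter (e.g. the conda env's python).
    cmds = [[sys.executable, *PROCESS_ARGS[dataset]]]
    cmds += [[sys.executable, "create_dataset.py", "--config-name", cfg]
             for cfg in CREATE_CONFIGS[dataset]]
    return cmds


def run(dataset, dry_run=True):
    """Print each command; execute it, stopping at the first failure, unless dry_run is set."""
    for cmd in pipeline(dataset):
        print(" ".join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)


if __name__ == "__main__" and len(sys.argv) > 1:
    # Usage: python prepare_data.py ronin [--run]
    run(sys.argv[1], dry_run="--run" not in sys.argv[2:])
```

Keeping the dry run as the default makes it easy to compare the printed commands against the list above before starting a long preprocessing job.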

## Usage
To train HiSS models for sequential prediction, use `train.py`. For each dataset, we provide a `<dataset_name>_hiss_config.yaml` file in the `conf/` directory containing the parameters of the best-performing HiSS model for that dataset. To train a model, run

```
python train.py --config-name <dataset_name>_hiss_config
```
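To train on every dataset in turn, a launcher can discover the provided HiSS configs in `conf/` and call `train.py` once per config. This is a hypothetical helper, not part of the repository; it assumes the config files sit directly in `conf/` and end in `_hiss_config.yaml`, and that the config name is the file name without `.yaml`. It prints the commands unless `--run` is passed.

```python
import os
import subprocess
import sys

SUFFIX = "_hiss_config.yaml"


def hiss_config_names(filenames):
    """Map file names ending in _hiss_config.yaml to config names (file name minus .yaml)."""
    return sorted(name[: -len(".yaml")] for name in filenames if name.endswith(SUFFIX))


def train_all(conf_dir="conf", dry_run=True):
    """Launch train.py once per HiSS config found in conf_dir, one after another."""
    for config_name in hiss_config_names(os.listdir(conf_dir)):
        cmd = [sys.executable, "train.py", "--config-name", config_name]
        print(" ".join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)


if __name__ == "__main__" and os.path.isdir("conf"):
    # Dry run by default; pass --run to actually start training.
    train_all(dry_run="--run" not in sys.argv[1:])
```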

To add a new dataset, create a corresponding `Task` object following the tasks defined in `vt_state/tasks`, and add a config file in `conf/data_env/`.