{"id":13444643,"url":"https://github.com/state-spaces/s4","last_synced_at":"2025-12-15T03:03:31.630Z","repository":{"id":37421293,"uuid":"424280831","full_name":"state-spaces/s4","owner":"state-spaces","description":"Structured state space sequence models","archived":false,"fork":false,"pushed_at":"2024-07-17T17:04:39.000Z","size":44199,"stargazers_count":2621,"open_issues_count":53,"forks_count":322,"subscribers_count":52,"default_branch":"main","last_synced_at":"2025-05-15T03:03:49.523Z","etag":null,"topics":["pytorch","sequence-models","state-space-models"],"latest_commit_sha":null,"homepage":"","language":"Jupyter Notebook","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"apache-2.0","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/state-spaces.png","metadata":{"files":{"readme":"README.md","changelog":"CHANGELOG.md","contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null}},"created_at":"2021-11-03T15:33:53.000Z","updated_at":"2025-05-15T02:50:31.000Z","dependencies_parsed_at":"2023-01-21T12:00:11.897Z","dependency_job_id":"b6c8aa23-b87a-4f0d-a81f-641a5547621e","html_url":"https://github.com/state-spaces/s4","commit_stats":null,"previous_names":["state-spaces/s4","hazyresearch/state-spaces"],"tags_count":5,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/state-spaces%2Fs4","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/state-spaces%2Fs4/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/state-spaces%2Fs4/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/state-spaces%2Fs4/manifests","owner_url":"https://repos.e
cosyste.ms/api/v1/hosts/GitHub/owners/state-spaces","download_url":"https://codeload.github.com/state-spaces/s4/tar.gz/refs/heads/main","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":254264765,"owners_count":22041793,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["pytorch","sequence-models","state-space-models"],"created_at":"2024-07-31T04:00:32.798Z","updated_at":"2025-10-05T09:28:33.867Z","avatar_url":"https://github.com/state-spaces.png","language":"Jupyter Notebook","readme":"# Structured State Spaces for Sequence Modeling\n\nThis repository provides the official implementations and experiments for models related to [S4](https://arxiv.org/abs/2111.00396),\nincluding [HiPPO](https://arxiv.org/abs/2008.07669), [LSSL](https://arxiv.org/abs/2110.13985), [SaShiMi](https://arxiv.org/abs/2202.09729),\n[DSS](https://arxiv.org/abs/2203.14343), [HTTYH](https://arxiv.org/abs/2206.12037), [S4D](https://arxiv.org/abs/2206.11893),\nand [S4ND](https://arxiv.org/abs/2210.06583).\n\nProject-specific information for each of these models, including overview of the source code and specific experiment reproductions,\ncan be found under [models/](models/).\n\n\n## Table of Contents\n\nSetting up the environment and porting S4 to external codebases:\n- [Setup](#setup)\n- [Getting Started with S4](#getting-started-with-s4)\n\nUsing this repository for training models:\n- [Training](#training)\n- [Generation](#generation)\n- [Repository Structure](#overall-repository-structure)\n- [Citation](#citation)\n\n### Changelog\nSee 
[CHANGELOG.md](CHANGELOG.md)\n\n### Roadmap\n- More documentation for training from scratch using this repository\n- Compilation of S4 resources and implementations\n- pip package\n\n\n## Setup\n\n### Requirements\nThis repository requires Python 3.9+ and PyTorch 1.10+.\nIt has been tested up to PyTorch 1.13.1.\nOther packages are listed in [requirements.txt](./requirements.txt).\nSome care may be needed to make the library versions compatible, particularly torch/torchvision/torchaudio/torchtext.\n\nExample installation:\n```\nconda install pytorch==1.13.1 torchvision==0.14.1 torchaudio==0.13.1 pytorch-cuda=11.6 -c pytorch -c nvidia\npip install -r requirements.txt\n```\n\n\n### Structured Kernels\n\nCore operations of S4 are the Cauchy and Vandermonde kernels described in the [paper](https://arxiv.org/abs/2111.00396).\nThese are very simple matrix multiplications; naive implementations of these operations can be found in the [standalone file](models/s4/s4.py) in the functions `cauchy_naive` and `log_vandermonde_naive`.\nHowever, as the paper describes, this has suboptimal memory usage that currently requires a custom kernel to overcome in PyTorch.\n\nTwo more efficient methods are supported. 
The code will automatically detect if either of these is installed and call the appropriate kernel.\n\n#### Custom CUDA Kernel\n\nThis version is faster but requires manual compilation for each machine environment.\nRun `python setup.py install` from the directory `extensions/kernels/`.\n\n#### PyKeOps\n\nThis version is provided by the [pykeops library](https://www.kernel-operations.io/keops/python/installation.html).\nInstallation usually works out of the box with `pip install pykeops cmake`, both of which are also listed in the requirements file.\n\n\n## Getting Started with S4\n\n### S4 Module\n\nSelf-contained files for the S4 layer and variants can be found in [models/s4/](./models/s4/),\nwhich includes instructions for calling the module.\n\nSee [notebooks/](notebooks/) for visualizations explaining some concepts behind HiPPO and S4.\n\n### Example Train Script (External Usage)\n\n[example.py](example.py) is a self-contained training script for MNIST and CIFAR that imports the standalone S4 file. With the default settings, `python example.py` reaches 88% accuracy on sequential CIFAR with a very simple S4D model of 200k parameters.\nThis script can be used as an example for using S4 variants in external repositories.\n\n### Training with this Repository (Internal Usage)\n\nThis repository aims to provide a very flexible framework for training sequence models. Many models and datasets are supported.\n\nThe basic entrypoint is `python -m train`, or equivalently\n```\npython -m train pipeline=mnist model=s4\n```\nwhich trains an S4 model on the Permuted MNIST dataset.\nThis should reach around 90% accuracy after 1 epoch, which takes 1-3 minutes depending on the GPU.\n\nMore examples of using this repository are documented throughout. 
See [Training](#training) for an overview.\n\n### Optimizer Hyperparameters\n\nOne important feature of this codebase is support for parameters that require different optimizer hyperparameters.\nIn particular, the SSM kernel is sensitive to the $(A, B)$ (and sometimes $\\Delta$) parameters,\nso the learning rate on these parameters is sometimes lowered and the weight decay is always set to $0$.\n\nSee the method `register` in the model (e.g. [s4d.py](models/s4/s4d.py)) and the function `setup_optimizer` in the training script (e.g. [example.py](example.py)) for examples of how to implement this in external repos.\n\n\u003c!--\nOur logic for setting these parameters can be found in the `OptimModule` class under `src/models/sequence/ss/kernel.py` and the corresponding optimizer hook in `SequenceLightningModule.configure_optimizers` under `train.py`\n--\u003e\n\n\n## Training\n\nThe core training infrastructure of this repository is based on [PyTorch Lightning](https://pytorch-lightning.readthedocs.io/en/latest/) with a configuration scheme based on [Hydra](https://hydra.cc/docs/intro/).\n\nThe main entrypoint is `train.py`, and configs are found in `configs/`.\n\n### Data\n\nBasic datasets are auto-downloaded, including MNIST, CIFAR, and Speech Commands.\nAll logic for creating and loading datasets is in the [src/dataloaders](./src/dataloaders/) directory.\nThe README inside this subdirectory documents how to download and organize other datasets.\n\n### Models\n\nModels are defined in [src/models](src/models). 
See the README in this subdirectory for an overview.\n\n\n### Configs and Hyperparameters\nPre-defined configs reproducing end-to-end experiments from the papers are provided, found under project-specific information in [models/](models/), such as for the [original S4 paper](models/s4/experiments.md).\n\nConfigs can also be easily modified through the command line.\nAn example experiment is\n```\npython -m train pipeline=mnist dataset.permute=True model=s4 model.n_layers=3 model.d_model=128 model.norm=batch model.prenorm=True wandb=null\n```\nThis trains an S4 model on the Permuted MNIST task with the specified number of layers, backbone dimension, and normalization type.\n\nSee [configs/README.md](configs/README.md) for more detailed documentation about the configs.\n\n\n#### Hydra\n\nIt is recommended to read the [Hydra documentation](https://hydra.cc/docs/intro/) to fully understand the configuration framework. For help launching specific experiments, please file an issue.\n\n\u003c!--\n#### Registries\n\nThis codebase uses a modification of the hydra `instantiate` utility that provides shorthand names of different classes, for convenience in configuration and logging.\nThe mapping from shorthand to full path can be found in `src/utils/registry.py`.\n--\u003e\n\n\n### Resuming\n\nEach experiment will be logged to its own directory (generated by Hydra) of the form `./outputs/\u003cdate\u003e/\u003ctime\u003e/`. Checkpoints will be saved inside this folder, and their paths printed to the console whenever a new checkpoint is created.\nTo resume training, simply point to the desired `.ckpt` file (a PyTorch Lightning checkpoint, e.g. 
`./outputs/\u003cdate\u003e/\u003ctime\u003e/checkpoints/val/loss.ckpt`) and append the flag `train.ckpt=\u003cpath\u003e/\u003cto\u003e/\u003ccheckpoint\u003e.ckpt` to the original training command.\n\n### PyTorch Lightning Trainer\n\nThe PTL [Trainer](https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html) class controls the overall training loop and also provides many useful pre-defined flags. Some useful examples are explained below.\nThe full list of allowable flags can be found in the PTL documentation, as well as in our [trainer configs](configs/trainer/). See the default trainer config [configs/trainer/default.yaml](configs/trainer/default.yaml) for the most useful options.\n\n#### Multi-GPU training\n\nSimply pass in `trainer.gpus=2` to train with 2 GPUs.\n\n#### Inspect model layers\n\n`trainer.weights_summary=full` prints out every layer of the model with its parameter count. Useful for debugging the internals of models.\n\n#### Data subsampling\n`trainer.limit_{train,val}_batches={10,0.1}` trains (validates) on only 10 batches (or a 0.1 fraction of all batches). Useful for testing the train loop without going through all the data.\n\n\n### WandB\n\nLogging with [WandB](https://wandb.ai/site) is built into this repository.\nIn order to use this, simply set your `WANDB_API_KEY` environment variable, and change the `wandb.project` attribute of [configs/config.yaml](configs/config.yaml) (or pass it on the command line, e.g. `python -m train .... 
wandb.project=s4`).\n\nSet `wandb=null` to turn off WandB logging.\n\n\n## Generation\n\nAutoregressive generation can be performed with the [generate.py](generate.py) script.\nThis script can be used in two ways after training a model using this codebase.\n\n### Option 1: Checkpoint Path\nThe more flexible option requires the checkpoint path of the trained PyTorch Lightning model.\nThe generation script accepts the same config options as the train script, with a few additional flags that are documented in [configs/generate.yaml](configs/generate.yaml).\nAfter training with `python -m train \u003ctrain flags\u003e`, generate with\n```\npython -m generate \u003ctrain flags\u003e checkpoint_path=\u003cpath/to/model.ckpt\u003e \u003cgeneration flags\u003e\n```\nAny of the flags found in the config can be overridden.\n\nNote: This option can be used with either `.ckpt` checkpoints (PyTorch Lightning, which includes information for the Trainer) or `.pt` checkpoints (PyTorch, which is just a model state dict).\n\n### Option 2: Experiment Path\nThe second option for generation does not require passing in training flags again, and instead reads the config from the Hydra experiment folder, along with a PyTorch Lightning checkpoint within the experiment folder.\n\n### Example 1 (Language)\n\nDownload the [WikiText-103 model checkpoint](https://huggingface.co/krandiash/sashimi-release/tree/main/checkpoints), for example to `./checkpoints/s4-wt103.pt`.\nThis model was trained with the command `python -m train experiment=lm/s4-wt103`. Note that from the config we can see that the model was trained with a receptive field of length 8192.\n\nTo generate, run\n```\npython -m generate experiment=lm/s4-wt103 checkpoint_path=checkpoints/s4-wt103.pt n_samples=1 l_sample=16384 l_prefix=8192 decode=text\n```\nThis generates a sample of length 16384 conditioned on a prefix of length 8192.\n\n### Example 2 (Audio)\n\nLet's train a small SaShiMi model on the SC09 dataset. 
We can also reduce the number of training and validation batches to get a checkpoint faster:\n```\npython -m train experiment=audio/sashimi-sc09 model.n_layers=2 trainer.limit_train_batches=0.1 trainer.limit_val_batches=0.1\n```\n\nAfter the first epoch completes, a message is printed indicating where the checkpoint is saved.\n```\nEpoch 0, global step 96: val/loss reached 3.71754 (best 3.71754), saving model to \"\u003crepository\u003e/outputs/\u003cdate\u003e/\u003ctime\u003e/checkpoints/val/loss.ckpt\"\n```\n\nOption 1:\n```\npython -m generate experiment=audio/sashimi-sc09 model.n_layers=2 checkpoint_path=\u003crepository\u003e/outputs/\u003cdate\u003e/\u003ctime\u003e/checkpoints/val/loss.ckpt n_samples=4 l_sample=16000\n```\nThis option redefines the full config so that the model and dataset can be constructed.\n\nOption 2:\n```\npython -m generate experiment_path=\u003crepository\u003e/outputs/\u003cdate\u003e/\u003ctime\u003e checkpoint_path=checkpoints/val/loss.ckpt n_samples=4 l_sample=16000\n```\nThis option only needs the path to the Hydra experiment folder and the desired checkpoint within.\n\n\n## Overall Repository Structure\n```\nconfigs/         Config files for model, data pipeline, training loop, etc.\ndata/            Default location of raw data\nextensions/      CUDA extensions (Cauchy and Vandermonde kernels)\nsrc/             Main source code for models, datasets, etc.\n  callbacks/     Training loop utilities (e.g. 
checkpointing)\n  dataloaders/   Dataset and dataloader definitions\n  models/        Model definitions\n  tasks/         Encoder/decoder modules to interface between data and model backbone\n  utils/\nmodels/          Model-specific information (code, experiments, additional resources)\nexample.py       Example training script for using S4 externally\ntrain.py         Training entrypoint for this repo\ngenerate.py      Autoregressive generation script\n```\n\n\n## Citation\nIf you use this codebase, or otherwise found our work valuable, please cite S4 and [other relevant papers](models/README.md#citations).\n\n```\n@inproceedings{gu2022efficiently,\n  title={Efficiently Modeling Long Sequences with Structured State Spaces},\n  author={Gu, Albert and Goel, Karan and R\\'e, Christopher},\n  booktitle={The International Conference on Learning Representations ({ICLR})},\n  year={2022}\n}\n```\n","funding_links":[],"categories":["Before 2023","Jupyter Notebook","People and works"],"sub_categories":["Interesting GitHub Repositories"],"project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fstate-spaces%2Fs4","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fstate-spaces%2Fs4","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fstate-spaces%2Fs4/lists"}