{"id":13444413,"url":"https://srush.github.io/annotated-s4/","last_synced_at":"2025-03-20T18:32:53.926Z","repository":{"id":38037151,"uuid":"436415334","full_name":"srush/annotated-s4","owner":"srush","description":"Implementation of https://srush.github.io/annotated-s4","archived":false,"fork":false,"pushed_at":"2023-02-01T20:01:28.000Z","size":85178,"stargazers_count":485,"open_issues_count":7,"forks_count":62,"subscribers_count":9,"default_branch":"main","last_synced_at":"2025-03-09T07:16:31.566Z","etag":null,"topics":["deep-learning","jax"],"latest_commit_sha":null,"homepage":"https://srush.github.io/annotated-s4","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"mit","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/srush.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null}},"created_at":"2021-12-08T22:52:35.000Z","updated_at":"2025-03-01T10:34:37.000Z","dependencies_parsed_at":"2023-02-17T10:30:23.745Z","dependency_job_id":null,"html_url":"https://github.com/srush/annotated-s4","commit_stats":null,"previous_names":[],"tags_count":1,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/srush%2Fannotated-s4","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/srush%2Fannotated-s4/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/srush%2Fannotated-s4/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/srush%2Fannotated-s4/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/srush","download_url":"https://codeload.github.com/srush/annotated-s4/tar.gz/refs/heads/main","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":244670690,"owners_count":20491040,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["deep-learning","jax"],"created_at":"2024-07-31T04:00:22.352Z","updated_at":"2025-03-20T18:32:49.443Z","avatar_url":"https://github.com/srush.png","language":"Python","readme":"\n* **[Link To The Blog Post](https://srush.github.io/annotated-s4)**\n\n\n\u003ca href=\"https://srush.github.io/annotated-s4\"\u003e\u003cimg src=\"https://user-images.githubusercontent.com/35882/149201164-1723a44a-f34b-467c-94b0-ffda5ebcabbb.png\"\u003e\u003c/a\u003e\n\n\n\n## Experiments\n\n#### MNIST Sequence Modeling\n\n```bash\npython -m s4.train dataset=mnist layer=s4 train.epochs=100 train.bsz=128 model.d_model=128 model.layer.N=64\n```\n\nThe following command uses a larger model (5M params) and logs generated samples to wandb every epoch. It achieves 0.36 test NLL (0.52 bits per dimension), which is state of the art on this task.\n\n```bash\npython -m s4.train dataset=mnist layer=s4 train.epochs=100 train.bsz=50 train.lr=5e-3 train.lr_schedule=true model.d_model=512 model.n_layers=6 model.dropout=0.0 train.weight_decay=0.05 model.prenorm=true model.embedding=true wandb.mode=online train.sample=308\n```\n\n#### QuickDraw Sequence Modeling\n\n```bash\n# Default arguments\npython -m s4.train dataset=quickdraw layer=s4 train.epochs=10 train.bsz=128 model.d_model=128 model.layer.N=64\n\n# \"Run in a day\" variant\npython -m s4.train dataset=quickdraw layer=s4 train.epochs=1 train.bsz=512 model.d_model=256 model.layer.N=64 model.dropout=0.05\n```\n\n#### MNIST Classification\n\n```bash\npython -m s4.train dataset=mnist-classification layer=s4 train.epochs=20 train.bsz=128 model.d_model=128 model.dropout=0.25 train.lr=5e-3 train.lr_schedule=true seed=1\n```\n\nS4 gets \"best\" 99.55% accuracy after 20 epochs @ 17s/epoch on an A100\n\n#### CIFAR-10 Classification\n\nReplace `{s4,dss,s4d}` with one of the three layer types before running; pasted as-is, the shell would expand it into three separate arguments.\n\n```bash\npython -m s4.train dataset=cifar-classification layer={s4,dss,s4d} train.epochs=100 train.bsz=50 model.n_layers=6 model.d_model=512 model.dropout=0.25 train.lr=5e-3 train.weight_decay=0.01 train.lr_schedule=true seed=1\n```\n\nS4 gets \"best\" 91.23% accuracy after 100 epochs @ 2m16s/epoch on an A100\n\nDSS gets \"best\" 89.31% accuracy after 100 epochs @ 1m41s/epoch on an A100\n\nS4D gets \"best\" 89.76% accuracy after 100 epochs @ 1m32s/epoch on an A100\n\nThe alternative S4D-Lin initialization (`+model.layer.scaling=linear`) performs slightly better, reaching 90.98% accuracy:\n\n```bash\npython -m s4.train dataset=cifar-classification layer=s4d train.epochs=100 train.bsz=50 model.n_layers=6 model.d_model=512 model.dropout=0.25 train.lr=5e-3 train.weight_decay=0.01 train.lr_schedule=true seed=1 +model.layer.scaling=linear\n```\n\n\n---\n\n## Quickstart (Development)\n\nWe provide two requirements files for this project: `requirements-cpu.txt` for CPU-only machines and `requirements-gpu.txt` for GPUs.\n\n### CPU-Only (macOS, Linux)\n\n```bash\n# Set up virtual/conda environment of your choosing \u0026 activate...\npip install -r requirements-cpu.txt\n\n# Set up pre-commit\npre-commit install\n```\n\n### GPU (CUDA \u003e 11 \u0026 cuDNN \u003e 8.2)\n\n```bash\n# Set up virtual/conda environment of your choosing \u0026 activate...\npip install -r requirements-gpu.txt\n\n# Set up pre-commit\npre-commit install\n```\n\n## Dependencies from Scratch\n\nIf the requirements files above don't work, here are the commands we used to install the dependencies.\n\n### CPU-Only\n\n```bash\n# Set up virtual/conda environment of your choosing \u0026 activate... then install the following:\npip install --upgrade \"jax[cpu]\"\npip install flax\npip install torch torchvision torchaudio\n\n# Defaults\npip install black celluloid flake8 google-cloud-storage isort ipython matplotlib pre-commit seaborn tensorflow tqdm\n\n# Set up pre-commit\npre-commit install\n```\n\n### GPU (CUDA \u003e 11, cuDNN \u003e 8.2)\n\nNote: cuDNN \u003e 8.2 is critical for compiling without warnings, and a GPU with at least the Turing architecture is needed for full efficiency.\n\n```bash\n# Set up virtual/conda environment of your choosing \u0026 activate... then install the following:\npip install \"jax[cuda11_cudnn82]\" -f https://storage.googleapis.com/jax-releases/jax_releases.html\npip install flax\n# CPU-only PyTorch wheels (the models themselves run in JAX)\npip install torch==1.10.1+cpu torchvision==0.11.2+cpu torchaudio==0.10.1+cpu -f https://download.pytorch.org/whl/cpu/torch_stable.html\n\n# Defaults\npip install black celluloid flake8 google-cloud-storage isort ipython matplotlib pre-commit seaborn tensorflow tqdm\n\n# Set up pre-commit\npre-commit install\n```\n","funding_links":[],"categories":["Tutorials","Tutorials \u003ca name=\"tutorials\"\u003e\u003c/a\u003e","Before 2023","Publications, Annotations and Visualizations"],"sub_categories":["Blogs"],"project_url":"https://awesome.ecosyste.ms/api/v1/projects/srush.github.io%2Fannotated-s4%2F","html_url":"https://awesome.ecosyste.ms/projects/srush.github.io%2Fannotated-s4%2F","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/srush.github.io%2Fannotated-s4%2F/lists"}