{"id":15600962,"url":"https://github.com/lucidrains/gateloop-transformer","last_synced_at":"2025-04-13T04:16:08.885Z","repository":{"id":205876075,"uuid":"715305871","full_name":"lucidrains/gateloop-transformer","owner":"lucidrains","description":"Implementation of GateLoop Transformer in Pytorch and Jax","archived":false,"fork":false,"pushed_at":"2024-06-18T21:07:39.000Z","size":36057,"stargazers_count":87,"open_issues_count":0,"forks_count":11,"subscribers_count":11,"default_branch":"main","last_synced_at":"2025-04-13T04:15:50.341Z","etag":null,"topics":["artificial-intelligence","deep-learning","sequence-modeling"],"latest_commit_sha":null,"homepage":"","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"mit","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/lucidrains.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null}},"created_at":"2023-11-06T21:56:40.000Z","updated_at":"2025-01-08T04:37:58.000Z","dependencies_parsed_at":null,"dependency_job_id":"f27ccbd2-612f-478a-b226-ad3bf3a4d91f","html_url":"https://github.com/lucidrains/gateloop-transformer","commit_stats":null,"previous_names":["lucidrains/gateloop-pytorch"],"tags_count":35,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/lucidrains%2Fgateloop-transformer","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/lucidrains%2Fgateloop-transformer/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/lucidrains%2Fgateloop-transformer/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/lucidrains%2Fgateloop-transformer/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/lucidrains","download_url":"https://codeload.github.com/lucidrains/gateloop-transformer/tar.gz/refs/heads/main","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":248661719,"owners_count":21141451,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["artificial-intelligence","deep-learning","sequence-modeling"],"created_at":"2024-10-03T02:10:25.288Z","updated_at":"2025-04-13T04:16:08.853Z","avatar_url":"https://github.com/lucidrains.png","language":"Python","funding_links":[],"categories":[],"sub_categories":[],"readme":"\u003cimg src=\"./gateloop.png\" width=\"450px\"\u003e\u003c/img\u003e\n\n## GateLoop Transformer\n\nImplementation of \u003ca href=\"https://arxiv.org/abs/2311.01927\"\u003eGateLoop\u003c/a\u003e Transformer in PyTorch and JAX, to be tested on Enwik8 character-level modeling.\n\nUpdate: A transformer run with regular attention + data-dependent xpos relative positions did not converge at all. Also, gate loop's associative scan is not able to train even at a sequence length of 128. I'm not sure if it can be done without a specialized CUDA kernel, much like autoregressive linear attention (RWKV and the like).\n\nUpdate 2: Got a smaller GateLoop transformer (gate loop dimension of 128) to run at a sequence length of 256. Judging by a quick eyeball, it is converging very well. 
Will run some more rigorous experiments tomorrow.\n\nUpdate 3: Fixed a misunderstanding, and it definitely seems to be converging better than vanilla linear attention (going from my memory of those experiments).\n\nUpdate 4: \u003ca href=\"https://api.wandb.ai/links/lucidrains/ysbz84fn\"\u003eOngoing experiments\u003c/a\u003e\n\nUpdate 5: The author has reviewed the code, and there was another misunderstanding: they use maximum heads (heads == dimension). This is a bit of a plot twist, as this is infeasible for normal attention. It also obviates the need for the fused CUDA kernel that autoregressive linear attention requires.\n\nUpdate 6: The corrected gateloop transformer run looks amazing. Cautiously optimistic now.\n\nUpdate 7: Ablating the state transition shows the expected negative result. Ablating complex-valued states, however, makes no difference that I can see, at least early in the run.\n\nUpdate 8: Directly projecting to `kv` with one projection in the max-heads setting (instead of projecting keys and values separately, then multiplying them element-wise) yields similar results.\n\nUpdate 9: \u003ca href=\"https://api.wandb.ai/links/lucidrains/do1i9rx0\"\u003eHead-to-head run to 20k steps\u003c/a\u003e, just to make sure GateLoop doesn't get overtaken later on.\n\nUpdate 10: And it was overtaken by attention, at least assuming the implementation in the repo is correct.\n\nUpdate 11: I'm seeing steady improvement as the head dimension increases, so I no longer believe max-heads is optimal. Increasing the head dimension brings us right back to linear attention and the need for a fused CUDA kernel.\n\nUpdate 12: \u003ca href=\"https://github.com/cnapun\"\u003eNikil\u003c/a\u003e spotted a potential error: `kv` was not being kept complex (it should stay complex, with the real component taken only at the end). \u003ca href=\"https://api.wandb.ai/links/lucidrains/lgz368mf\"\u003eRerunning experiments\u003c/a\u003e.\n\nUpdate 13: Still clearly worse than attention.\n\nUpdate 14: I see some synergy when mixing gateloop and attention at a small scale, holding parameters constant. 
I will be adding a few simplified gateloop layers to transformers in future projects, to address one of the main weaknesses of attention.\n\nUpdate 15: There may be a way to combine associative-scan-based works with the findings from the recently proposed \u003ca href=\"https://arxiv.org/abs/2312.04927\"\u003eTaylor series linear attention\u003c/a\u003e. I will carry out some independent research before the end of January 2024 and share the results here.\n\n### Appreciation\n\n- \u003ca href=\"https://stability.ai/\"\u003eStabilityAI\u003c/a\u003e, \u003ca href=\"https://a16z.com/supporting-the-open-source-ai-community/\"\u003eA16Z Open Source AI Grant Program\u003c/a\u003e, and \u003ca href=\"https://huggingface.co/\"\u003e🤗 Huggingface\u003c/a\u003e for the generous sponsorships, as well as my other sponsors, for affording me the independence to open source current artificial intelligence research\n\n### Install\n\n```bash\n$ pip install gateloop-transformer\n```\n\n### Usage\n\n```python\nimport torch\nfrom gateloop_transformer import Transformer\n\nmodel = Transformer(\n    num_tokens = 256,\n    dim = 624,\n    depth = 6,\n    use_gate_looped_attn = True\n)\n\nids = torch.randint(0, 256, (1, 1024))\nlogits = model(ids) # (1, 1024, 256)\n```\n\nA simplified gate loop layer\n\n```python\nimport torch\nfrom gateloop_transformer import SimpleGateLoopLayer\n\ngateloop = SimpleGateLoopLayer(512)\n\nx = torch.randn(1, 65536, 512)\nx = gateloop(x) + x # residual connection\n```\n\n### Character-level Language Modeling\n\nInstall the requirements\n\n```bash\n$ pip install -r requirements.txt\n```\n\nThen run the `train.py` script for autoregressive modeling on enwik8\n\n```bash\n$ python train.py\n```\n\n### Todo\n\n- [x] jax version with equinox\n- [x] start with naive memory checkpointing of the gate loop operation\n- [x] retry the failed full attention experiments (with data-dependent xpos), but with complex-valued scales (didn't work)\n- [x] separate out a minimal gateloop circuit, to augment attention, 
rather than to replace it, as done in \u003ca href=\"https://arxiv.org/abs/2209.10655\"\u003eMega\u003c/a\u003e\n- [x] experiments\n    - [x] do all the ablations and figure out how much the data-controlled state transitions add (as well as whether they need to be complex)\n    - [x] do complete runs of transformer + rotary against gateloop with max heads, parameters held constant, to 20k steps\n- [x] just use jax's associative scan, wrapped with jax2torch, for now. The PyTorch team claims they will implement \u003ca href=\"https://github.com/pytorch/pytorch/issues/95408\"\u003ethis\u003c/a\u003e eventually\n\n## Citations\n\n```bibtex\n@inproceedings{Katsch2023GateLoopFD,\n    title   = {GateLoop: Fully Data-Controlled Linear Recurrence for Sequence Modeling},\n    author  = {Tobias Katsch},\n    year    = {2023},\n    url     = {https://api.semanticscholar.org/CorpusID:265018962}\n}\n```\n\n```bibtex\n@inproceedings{Heinsen2023EfficientPO,\n    title   = {Efficient Parallelization of a Ubiquitous Sequential Computation},\n    author  = {Franz A. Heinsen},\n    year    = {2023},\n    url     = {https://api.semanticscholar.org/CorpusID:265213659}\n}\n```\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Flucidrains%2Fgateloop-transformer","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Flucidrains%2Fgateloop-transformer","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Flucidrains%2Fgateloop-transformer/lists"}