<img src="./flash.png" width="500px"></img>

## FLASH - Pytorch

Implementation of the Transformer variant proposed in the paper <a href="https://arxiv.org/abs/2202.10447">Transformer Quality in Linear Time</a>

## Install

```bash
$ pip install FLASH-pytorch
```

## Usage

The main novel circuit in this paper is the "Gated Attention Unit" (GAU), which the authors claim can replace multi-headed attention while using just a single head.

It replaces the softmax with a relu squared activation, an activation first seen in the <a href="https://arxiv.org/abs/2109.08668">Primer paper</a>, while the use of ReLU in attention was explored in the <a href="https://arxiv.org/abs/2104.07012">ReLA Transformer</a>. The gating style seems mostly inspired by <a href="https://arxiv.org/abs/2105.08050">gMLPs</a>.

```python
import torch
from flash_pytorch import GAU

gau = GAU(
    dim = 512,
    query_key_dim = 128,     # query / key dimension
    causal = True,           # autoregressive or not
    expansion_factor = 2,    # hidden dimension = dim * expansion_factor
    laplace_attn_fn = True   # the Mega paper claims this is more stable than relu squared as the attention function
)

x = torch.randn(1, 1024, 512)
out = gau(x) # (1, 1024, 512)
```

The authors then combine `GAU` with Katharopoulos linear attention, grouping the sequence into chunks to overcome a known issue with autoregressive linear attention.

They named this combination of the quadratic gated attention unit with grouped linear attention FLASH.

You can also use it quite easily:

```python
import torch
from flash_pytorch import FLASH

flash = FLASH(
    dim = 512,
    group_size = 256,             # group size
    causal = True,                # autoregressive or not
    query_key_dim = 128,          # query / key dimension
    expansion_factor = 2.,        # hidden dimension = dim * expansion_factor
    laplace_attn_fn = True        # the Mega paper claims this is more stable than relu squared as the attention function
)

x = torch.randn(1, 1111, 512)     # sequence will be auto-padded up to a multiple of the group size
out = flash(x) # (1, 1111, 512)
```

Finally, you can use the full FLASH transformer as described in the paper. It contains all the positional embeddings the paper mentions: the absolute positional embedding is scaled sinusoidal, the GAU quadratic attention gets a single-headed T5 relative positional bias, and on top of all this, both the GAU attention and the linear attention are rotary embedded (RoPE).

```python
import torch
from flash_pytorch import FLASHTransformer

model = FLASHTransformer(
    num_tokens = 20000,          # number of tokens
    dim = 512,                   # model dimension
    depth = 12,                  # depth
    causal = True,               # autoregressive or not
    group_size = 256,            # size of the groups
    query_key_dim = 128,         # dimension of queries / keys
    expansion_factor = 2.,       # hidden dimension = dim * expansion_factor
    norm_type = 'scalenorm',     # the paper claims scalenorm led to faster training at no performance cost; the other option is 'layernorm' (also default)
    shift_tokens = True          # discovered by an independent researcher in Shenzhen, @BlinkDL; this simply shifts half of the feature space forward one step along the sequence dimension, and it greatly improved convergence even further in my local experiments
)

x = torch.randint(0, 20000, (1, 1024))
logits = model(x) # (1, 1024, 20000)
```

## Test on Autoregressive Enwik8

```bash
$ python train.py
```

## Citations

```bibtex
@article{Hua2022TransformerQI,
    title   = {Transformer Quality in Linear Time},
    author  = {Weizhe Hua and Zihang Dai and Hanxiao Liu and Quoc V. Le},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2202.10447}
}
```

```bibtex
@software{peng_bo_2021_5196578,
    author    = {PENG Bo},
    title     = {BlinkDL/RWKV-LM: 0.01},
    month     = {aug},
    year      = {2021},
    publisher = {Zenodo},
    version   = {0.01},
    doi       = {10.5281/zenodo.5196578},
    url       = {https://doi.org/10.5281/zenodo.5196578}
}
```

```bibtex
@inproceedings{Ma2022MegaMA,
    title   = {Mega: Moving Average Equipped Gated Attention},
    author  = {Xuezhe Ma and Chunting Zhou and Xiang Kong and Junxian He and Liangke Gui and Graham Neubig and Jonathan May and Luke Zettlemoyer},
    year    = {2022}
}
```
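The `GAU` and `FLASH` examples above use relu squared as the attention function, or the Laplace function from the Mega paper when `laplace_attn_fn = True`. Below is a minimal sketch of both nonlinearities in plain PyTorch. The function names are hypothetical, and the Laplace constants (mu = sqrt(1/2), sigma = sqrt(1 / (4 pi))) are assumed from the Mega paper rather than read from this repository.

```python
import math

import torch


def relu_squared(scores: torch.Tensor) -> torch.Tensor:
    # Primer-style activation: zero out negatives, square positives
    return torch.relu(scores) ** 2


def laplace_attn(scores: torch.Tensor) -> torch.Tensor:
    # Laplace attention function; constants assumed from the Mega paper
    mu = math.sqrt(0.5)
    sigma = math.sqrt(1.0 / (4.0 * math.pi))
    # Gaussian CDF centred at mu, so outputs stay bounded in [0, 1]
    return 0.5 * (1.0 + torch.erf((scores - mu) / (sigma * math.sqrt(2.0))))


scores = torch.linspace(-2.0, 2.0, steps=5)
print(relu_squared(scores))
print(laplace_attn(scores))
```

Both functions are applied elementwise to the query/key similarities, so, unlike softmax, neither normalises across the sequence.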
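To make the gated attention unit more concrete, here is a minimal single-head sketch written from the paper's description, not from this library's `GAU` class. The gate `u` multiplies the attended values `v`, queries and keys come from one shared projection `z` via cheap per-dimension scales and offsets, and relu squared replaces the softmax. The class name, the `1 / seq_len` scaling of the similarities, and the omission of the relative position bias and dropout are simplifying assumptions.

```python
import torch
import torch.nn.functional as F
from torch import nn


class GatedAttentionUnitSketch(nn.Module):
    """Single-head gated attention unit, sketched from the paper's description."""

    def __init__(self, dim, query_key_dim=128, expansion_factor=2.0, causal=False):
        super().__init__()
        hidden_dim = int(dim * expansion_factor)
        self.causal = causal
        self.norm = nn.LayerNorm(dim)
        # one projection produces both the gate u and the values v
        self.to_hidden = nn.Linear(dim, hidden_dim * 2)
        # shared low-dimensional projection z, later turned into q and k
        self.to_qk = nn.Linear(dim, query_key_dim)
        # cheap per-dimension scale and offset: row 0 for queries, row 1 for keys
        self.gamma = nn.Parameter(torch.ones(2, query_key_dim))
        self.beta = nn.Parameter(torch.zeros(2, query_key_dim))
        self.to_out = nn.Linear(hidden_dim, dim)

    def forward(self, x):
        # x: (batch, seq_len, dim)
        seq_len = x.shape[-2]
        normed = self.norm(x)
        u, v = F.silu(self.to_hidden(normed)).chunk(2, dim=-1)
        z = F.silu(self.to_qk(normed))
        q = z * self.gamma[0] + self.beta[0]
        k = z * self.gamma[1] + self.beta[1]
        # relu squared over length-scaled similarities, in place of a softmax
        attn = F.relu(q @ k.transpose(-1, -2) / seq_len) ** 2
        if self.causal:
            future = torch.ones(seq_len, seq_len, dtype=torch.bool, device=x.device).triu(1)
            attn = attn.masked_fill(future, 0.0)
        # gate the attended values elementwise, project back, add the residual
        return x + self.to_out(u * (attn @ v))


gau = GatedAttentionUnitSketch(dim=512, query_key_dim=128, causal=True)
out = gau(torch.randn(1, 1024, 512))  # (1, 1024, 512)
```

Because there is a single head, the attention matrix is only `(seq_len, seq_len)` per batch element; the gating by `u` is what the authors rely on to make up for dropping multiple heads.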
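The grouping used by `FLASH` can be sketched in two steps: pad the sequence up to a multiple of `group_size` and reshape it into groups, then let each group attend to all earlier groups through causal linear attention, which only needs a running sum of per-group key/value summaries. This is a minimal sketch of that idea under stated assumptions: the function names are hypothetical, queries and keys are taken as already projected (no feature map and no length normalisation is applied), and the quadratic within-group attention is left out.

```python
import math

import torch
import torch.nn.functional as F


def pad_and_group(x: torch.Tensor, group_size: int):
    # x: (batch, seq_len, dim) -> (batch, num_groups, group_size, dim)
    batch, seq_len, dim = x.shape
    padded_len = math.ceil(seq_len / group_size) * group_size
    padding = padded_len - seq_len
    # F.pad pairs run from the last dim backwards: (dim_left, dim_right, seq_left, seq_right)
    x = F.pad(x, (0, 0, 0, padding), value=0.0)
    return x.reshape(batch, padded_len // group_size, group_size, dim), padding


def causal_linear_attn_across_groups(q, k, v):
    # q, k: (batch, groups, group_size, d); v: (batch, groups, group_size, e)
    # one d x e key/value summary per group
    kv = torch.einsum("bgnd,bgne->bgde", k, v)
    # running sum over groups, shifted so group g only sees groups 0..g-1
    kv = kv.cumsum(dim=1)
    kv = torch.cat((torch.zeros_like(kv[:, :1]), kv[:, :-1]), dim=1)
    return torch.einsum("bgnd,bgde->bgne", q, kv)


x = torch.randn(1, 1111, 512)
groups, padding = pad_and_group(x, group_size=256)
print(groups.shape, padding)  # torch.Size([1, 5, 256, 512]) 169
out = causal_linear_attn_across_groups(groups, groups, groups)  # (1, 5, 256, 512)
```

Zero padding leaves the key/value sums unchanged, since padded keys are all zeros, so the extra positions can simply be sliced off the output afterwards.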
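The scaled sinusoidal absolute positional embedding used by `FLASHTransformer` can be read as the standard sinusoidal embedding multiplied by a single learned scalar. The sketch below assumes that reading, the usual base of 10000, and a scale initialised to 1; the class name is hypothetical.

```python
import torch
from torch import nn


class ScaledSinusoidalEmbedding(nn.Module):
    def __init__(self, dim: int, base: float = 10000.0):
        super().__init__()
        assert dim % 2 == 0, "dim must be even to split into sin and cos halves"
        # a single learned scalar applied to the whole embedding
        self.scale = nn.Parameter(torch.ones(1))
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, seq_len, dim); returns (seq_len, dim), broadcastable over batch
        positions = torch.arange(x.shape[1], device=x.device).type_as(self.inv_freq)
        angles = torch.einsum("i,j->ij", positions, self.inv_freq)
        return torch.cat((angles.sin(), angles.cos()), dim=-1) * self.scale


emb = ScaledSinusoidalEmbedding(512)
tokens = torch.randn(1, 1024, 512)
tokens = tokens + emb(tokens)
```

The learned scale lets the model decide how strongly the absolute position signal competes with the token embeddings, while the sinusoid itself stays fixed.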
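Rotary embeddings (RoPE), which `FLASHTransformer` applies to both the GAU attention and the linear attention, rotate pairs of query/key features by a position-dependent angle so that the dot product of a rotated query and key depends only on their relative position. This is a minimal sketch using the "rotate half" pairing convention and a base of 10000, both assumptions; the library may pair features differently or rotate only part of the dimensions, and the function names are hypothetical.

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # (x1, x2) -> (-x2, x1) on the two halves of the last dim
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary(x: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    # x: (batch, seq_len, dim) with even dim, e.g. queries or keys
    seq_len, dim = x.shape[-2], x.shape[-1]
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=x.device).float() / dim))
    positions = torch.arange(seq_len, device=x.device).float()
    angles = torch.einsum("i,j->ij", positions, inv_freq)
    angles = torch.cat((angles, angles), dim=-1)  # (seq_len, dim)
    # 2D rotation of each (x1[i], x2[i]) pair by its position-dependent angle
    return x * angles.cos() + rotate_half(x) * angles.sin()


q = apply_rotary(torch.randn(1, 1024, 128))
k = apply_rotary(torch.randn(1, 1024, 128))
```

Since each feature pair is only rotated, vector norms are preserved and position 0 is left unchanged.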
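ScaleNorm, selected above with `norm_type = 'scalenorm'`, normalises each feature vector to unit L2 norm and rescales it by a single learned scalar instead of a per-feature gain and bias. This is a minimal sketch assuming the usual formulation, with the scalar initialised to `sqrt(dim)` so outputs start at norm `sqrt(dim)`; the `eps` value is an assumption, and this is not presented as the library's own class.

```python
import torch
from torch import nn


class ScaleNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        # single learned gain, initialised to sqrt(dim)
        self.g = nn.Parameter(torch.tensor(dim ** 0.5))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # divide each vector along the last dim by its L2 norm, then rescale
        norm = x.norm(dim=-1, keepdim=True).clamp(min=self.eps)
        return x / norm * self.g


normed = ScaleNorm(512)(torch.randn(1, 1024, 512))  # (1, 1024, 512)
```

With only one parameter per norm layer, ScaleNorm is cheaper than LayerNorm, which fits the paper's claim of faster training.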
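The `shift_tokens` option described in the `FLASHTransformer` example can be sketched as follows: split the features in half, move the first half one step forward along the sequence (so position `t` sees the features of position `t - 1`), and leave the second half untouched. This is a minimal illustration of the idea in that comment, with a hypothetical function name; where exactly the library applies the shift inside each layer is not shown.

```python
import torch


def shift_tokens(x: torch.Tensor) -> torch.Tensor:
    # x: (batch, seq_len, dim)
    x_shift, x_pass = x.chunk(2, dim=-1)
    # move the first half of the features forward one step; position 0 gets zeros
    x_shift = torch.cat((torch.zeros_like(x_shift[:, :1]), x_shift[:, :-1]), dim=1)
    return torch.cat((x_shift, x_pass), dim=-1)


x = torch.randn(1, 1024, 512)
shifted = shift_tokens(x)  # same shape, (1, 1024, 512)
```

Because the shift only looks backwards, it keeps causal models causal while giving every position cheap access to its predecessor.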