{"id":21082862,"url":"https://github.com/kklemon/flashperceiver","last_synced_at":"2025-05-16T09:32:19.374Z","repository":{"id":183315981,"uuid":"669927687","full_name":"kklemon/FlashPerceiver","owner":"kklemon","description":"Fast and memory efficient PyTorch implementation of the Perceiver with FlashAttention.","archived":false,"fork":false,"pushed_at":"2024-11-04T15:34:18.000Z","size":729,"stargazers_count":19,"open_issues_count":0,"forks_count":1,"subscribers_count":1,"default_branch":"master","last_synced_at":"2024-11-04T16:34:31.876Z","etag":null,"topics":["attention-mechanism","deep-learning","flash-attention","nlp","perceiver","transformer"],"latest_commit_sha":null,"homepage":"","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":null,"status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/kklemon.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":null,"code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null}},"created_at":"2023-07-23T22:07:13.000Z","updated_at":"2024-11-04T15:34:22.000Z","dependencies_parsed_at":null,"dependency_job_id":"d765cf53-002b-4d8c-b954-6d0954b62f0f","html_url":"https://github.com/kklemon/FlashPerceiver","commit_stats":null,"previous_names":["kklemon/fast-perceiver"],"tags_count":0,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/kklemon%2FFlashPerceiver","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/kklemon%2FFlashPerceiver/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/kklemon%2FFlashPerceiver/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/kklemon%2FFlashPerc
eiver/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/kklemon","download_url":"https://codeload.github.com/kklemon/FlashPerceiver/tar.gz/refs/heads/master","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":225419554,"owners_count":17471434,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["attention-mechanism","deep-learning","flash-attention","nlp","perceiver","transformer"],"created_at":"2024-11-19T20:15:29.460Z","updated_at":"2024-11-19T20:15:30.032Z","avatar_url":"https://github.com/kklemon.png","language":"Python","funding_links":[],"categories":[],"sub_categories":[],"readme":"FlashPerceiver\n=========================\n\nFast and memory-efficient PyTorch implementation of the Perceiver [1, 2, 3] architecture with FlashAttention [4, 5] as the attention backend.\n\n**Features:**\n\n* :zap: **More than 2x speedup over a naive implementation.**\n* :zap: **Sub-linear\u003csup\u003e1\u003c/sup\u003e memory usage with respect to input sequence length and linear usage with respect to the number of latent vectors.**\n* :zap: **Out-of-the-box support for rotary positional embeddings [6].**\n* :zap: **Uses the new and improved FlashAttention-2 implementation.**\n* :zap: **Support for multiple inputs and flexible masking.**\n\n\u003csup\u003e1\u003c/sup\u003e Applies to the attention components only. See [Performance](#performance) for more information.\n\nInstallation\n------------\n\n**Note:** The `pyproject.toml` has recently been removed from the flash-attn repository, and with it PEP 517 compliance. This means that flash-attn can no longer be declared as a dependency of this project and must be installed manually until this changes:\n\n```bash\npip install flash-attn --no-build-isolation\n```\n\nAfterwards, install the `flash-perceiver` package itself:\n\n```bash\npip install flash-perceiver\n```\n\nUsage\n-----\n\n### Perceiver\n\n![The Perceiver architecture](./figures/perceiver.png)\n\n```python\nimport torch\n\nfrom flash_perceiver import Perceiver, utils\n\nbatch_size, seq_len, in_dim = 32, 128, 256\n\nlatent_dim = 512\nnum_latents = 512\nout_dim = 128\n\nmodel = Perceiver(\n    input_dim=in_dim,\n    depth=8,\n    output_dim=out_dim,\n    num_latents=num_latents,\n    latent_dim=latent_dim,\n    cross_heads=1,\n    cross_head_dim=64,\n    cross_rotary_emb_dim=0,\n    cross_attn_dropout=0.0,\n    latent_heads=8,\n    latent_head_dim=64,\n    latent_rotary_emb_dim=0,\n    latent_attn_dropout=0.0,\n    weight_tie_layers=False,\n    gated_mlp=True,\n    self_per_cross_attn=1,\n    num_zero_tokens=None,\n    use_flash_attn=True,\n).cuda()\n\ndata = torch.randn(batch_size, seq_len, in_dim, device='cuda')\n\n# `output_dim` is set, so the final latents are averaged and projected.\n# FlashAttention only supports half precision (fp16/bf16), so the\n# forward pass runs under `torch.autocast`.\nwith torch.autocast('cuda'):\n    out = model(data)\n\nassert out.shape == (batch_size, out_dim)\n```\n\n**Multiple inputs**\n\nA separate input for each cross-attention block can be used by providing a list of inputs to the `forward` method. 
The number of inputs must equal the model's `depth`.\n\nBy passing a list of integers as the `input_dim` argument of the constructor, each input can have a different dimension.\n\n```python\ninput_dims = [256, 512]\n\nmodel = Perceiver(\n    input_dim=input_dims,\n    depth=2,  # must equal len(input_dims)\n    num_latents=num_latents,\n    latent_dim=latent_dim,\n).cuda()\n\ninputs = [\n    torch.randn(batch_size, seq_len, dim, device='cuda')\n    for dim in input_dims\n]\n\n# No `output_dim` is set, so no pooling or projection is applied\nwith torch.autocast('cuda'):\n    out = model(inputs)\n\nassert out.shape == (batch_size, num_latents, latent_dim)\n```\n\n**Masking**\n\nA boolean element-wise mask for the input can be provided. Positions where the mask is not `True` are masked out in the cross-attention operation. If a list of inputs is provided, a list with one mask per input can be passed as well; entries may be `None` for inputs without a mask.\n\n```python\n# `model` and `data` refer to the single-input example above\nmask = utils.random_mask(data)  # [batch_size, seq_len]\n\nwith torch.autocast('cuda'):\n    out = model(data, mask=mask)\n```\n\n**Extract Embeddings**\n\nIf a value for `output_dim` has been provided to the constructor, the final latent vectors are averaged and then projected to the desired dimension. To extract the representations before this projection, set `return_embeddings=True`:\n\n```python\nwith torch.autocast('cuda'):\n    embeds = model(data, return_embeddings=True)\n\nassert embeds.shape == (batch_size, num_latents, latent_dim)\n```\n\n**Custom Latents**\n\nFor some applications it can be useful to have custom sets of latent vectors. For instance, in a multi-task setting, each task could have a separate set of learned latents.\n\nThe `forward` method supports custom latents via the `latents` argument. If none are provided, the module's own latent vectors are used. 
These must have shape `[m, latent_dim]` or `[batch_size, m, latent_dim]`, where the number of latents $m$ can be chosen freely.\n\nTo skip creating the model's own learned latent vectors at construction time, pass `num_latents=None` to the constructor.\n\n**Extract Attention Weights**\n\n\u003e :warning: This is an experimental feature and requires a modified implementation of FlashAttention [until the changes are eventually merged](https://github.com/Dao-AILab/flash-attention/pull/589).\n\nPass `return_attn_weights=True` to the `forward` method of a model to extract the normalized attention weights of each attention layer. In this case, a tuple `(output, attn_weights)` is returned, where `attn_weights` is a list with one tensor per attention layer. This list follows the pattern `[cross_attn_0, self_attn_0_0, ..., cross_attn_1, self_attn_1_0, ...]`, where cross-attention maps have shape `(batch_size, cross_heads, num_latents, seq_len)` and self-attention maps have shape `(batch_size, latent_heads, num_latents, num_latents)`.\n\n```python\nwith torch.autocast('cuda'):\n    out, all_attn_weights = model(data, return_attn_weights=True)\n\n# Each block contributes one cross-attention map followed by its self-attention maps\nfor i, attn_weights in enumerate(all_attn_weights):\n    if i % model.num_attention_layers_per_block == 0:\n        print('cross-attention map with shape', attn_weights.shape)\n    else:\n        print('self-attention map with shape', attn_weights.shape)\n```\n\n### PerceiverIO\n\n[Perceiver IO](https://arxiv.org/abs/2107.14795) [2] is a variant of the Perceiver architecture in which the encoder tower is followed by a decoder module that computes task-specific outputs from sets of queries.\n\nThis makes the architecture more flexible, for example for position-specific decoding of values or multi-task settings.\n\n![The PerceiverIO architecture](./figures/perceiver-io.png)\n\n```python\nimport torch\n\nfrom flash_perceiver import PerceiverIO, utils\n\nbatch_size, seq_len, in_dim = 32, 128, 
256\n\ndepth = 8\nlatent_dim = 512\nnum_latents = 512\nquery_dim = 128\nnum_queries = 32\nproj_dim = 64\n\nmodel = PerceiverIO(\n    input_dim=in_dim,\n    query_dim=query_dim,\n    depth=depth,\n    proj_dim=proj_dim,\n    num_latents=num_latents,\n    latent_dim=latent_dim,\n    cross_heads=1,\n    cross_head_dim=64,\n    cross_rotary_emb_dim=0,\n    cross_attn_dropout=0.0,\n    latent_heads=8,\n    latent_head_dim=64,\n    latent_rotary_emb_dim=0,\n    latent_attn_dropout=0.0,\n    latent_drop=0.0,\n    query_heads=1,\n    query_head_dim=64,\n    query_rotary_emb_dim=0,\n    query_attn_dropout=0.0,\n    weight_tie_layers=False,\n    gated_mlp=True,\n    use_flash_attn=True,\n).cuda()\n\ndata = torch.randn(batch_size, seq_len, in_dim, device='cuda')\n\n# Queries can be learned or derived from positions, tokens, etc.\nqueries = torch.randn(num_queries, query_dim, device='cuda')\n\nwith torch.autocast('cuda'):\n    out = model(data, queries=queries)\n\nassert out.shape == (batch_size, num_queries, proj_dim)\n```\n\nExamples\n--------\n\nFurther usage examples can be found in the `examples/` folder.\n\nPerformance\n-----------\n\nThe Perceiver is designed as an attention architecture with sub-quadratic compute and memory complexity, in contrast to the quadratic requirements of a vanilla Transformer.\n\nA naive implementation has $\\mathcal{O}(nm)$ memory usage for the cross-attention modules and $\\mathcal{O}(m^2)$ for the self-attention or _latent_ blocks, where $n$ is the number of input elements, $m$ is the number of latent vectors (a fixed hyperparameter), and generally $n \\gg m$.\n\nFlashAttention can reduce the memory usage to $\\mathcal{O}(\\sqrt{nm})$ for the cross-attention layers and $\\mathcal{O}(m)$ for the latent self-attention layers. However, this only accounts for the computation of the attention mechanism itself. 
The input sequence and the corresponding keys and values within the cross-attention modules still grow linearly with $n$.\n\nUntil these tensors dominate memory usage, this implementation allows the input sequence length to be scaled considerably. For instance, 16x longer inputs can be processed than with [perceiver-pytorch](https://github.com/lucidrains/perceiver-pytorch) on an RTX 4090 with all other hyperparameters fixed (see `run_benchmarks.py` for the exact configuration).\n\n### Benchmarks\n\nBenchmarks against other implementations (currently only [perceiver-pytorch](https://github.com/lucidrains/perceiver-pytorch)) can be run with:\n\n```bash\npython run_benchmarks.py\n```\n\nThe script creates a `benchmark_results.csv` file, from which the `create_plots.py` script can generate plots.\n\nThe following results were obtained on an RTX 4090 with 24 GB of VRAM.\n\n![Benchmark results on speedup](figures/benchmark_speedup.png)\n\n![Benchmark results on memory usage reduction](figures/benchmark_memory_usage_reduction.png)\n\n**Note:** The batch size for each configuration is the largest value that fits in memory for all implementations. Especially for longer sequence lengths, this leads to decreasing GPU utilization and thus a lower speedup than theoretically possible. There are some ways to fix this, but my attempts so far have led to distorted results.\n\nAcknowledgements\n----------------\n\nThe implementation is inspired by lucidrains' [Perceiver implementation](https://github.com/lucidrains/perceiver-pytorch) and would not have been possible without Tri Dao's [FlashAttention](https://github.com/Dao-AILab/flash-attention).\n\nPlanned features\n----------------\n\nThese are a few features that are either planned or WIP. 
If you urgently need any of them, feel free to open an issue:\n\n- [X] Perceiver IO [2]\n- [ ] Perceiver AR [3] (or an AR demo in general)\n- [X] Demos\n- [X] Tests (see `tests/`)\n- [X] Allow more flexible cross-attention configurations\n- [ ] Benchmarks against other Perceiver implementations, e.g. [DeepMind's](https://github.com/deepmind/deepmind-research/tree/master/perceiver) or [Krasser's](https://github.com/krasserm/perceiver-io)\n- [ ] If FA2 is eventually merged into PyTorch, drop the flash-attn dependency\n- [ ] Configure and provide multiple inputs as a dict\n- [ ] TensorDict / tensorclass inputs\n- [X] Extract attention weights\n- [ ] Add fancy badges to the README\n- [ ] Use custom attention modules for more flexibility\n\nReferences\n----------\n\n[1] Jaegle, Andrew, Felix Gimeno, Andrew Brock, Andrew Zisserman, Oriol Vinyals, and Joao Carreira. “Perceiver: General Perception with Iterative Attention.” arXiv, June 22, 2021. http://arxiv.org/abs/2103.03206.\n\n[2] Jaegle, Andrew, Sebastian Borgeaud, Jean-Baptiste Alayrac, Carl Doersch, Catalin Ionescu, David Ding, Skanda Koppula, et al. “Perceiver IO: A General Architecture for Structured Inputs \u0026 Outputs.” arXiv, March 15, 2022. http://arxiv.org/abs/2107.14795.\n\n[3] Hawthorne, Curtis, Andrew Jaegle, Cătălina Cangea, Sebastian Borgeaud, Charlie Nash, Mateusz Malinowski, Sander Dieleman, et al. “General-Purpose, Long-Context Autoregressive Modeling with Perceiver AR.” arXiv, June 14, 2022. http://arxiv.org/abs/2202.07765.\n\n[4] Dao, Tri, Daniel Y. Fu, Stefano Ermon, Atri Rudra, and Christopher Ré. “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness.” arXiv, June 23, 2022. https://doi.org/10.48550/arXiv.2205.14135.\n\n[5] Dao, Tri. “FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning.” arXiv, July 17, 2023. https://doi.org/10.48550/arXiv.2307.08691.\n\n[6] Su, Jianlin, Yu Lu, Shengfeng Pan, Ahmed Murtadha, Bo Wen, and Yunfeng Liu. 
“RoFormer: Enhanced Transformer with Rotary Position Embedding.” arXiv, August 8, 2022. https://doi.org/10.48550/arXiv.2104.09864.\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fkklemon%2Fflashperceiver","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fkklemon%2Fflashperceiver","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fkklemon%2Fflashperceiver/lists"}