{"id":21121770,"url":"https://github.com/chengzeyi/paraattention","last_synced_at":"2025-04-09T11:06:59.959Z","repository":{"id":260597907,"uuid":"879677288","full_name":"chengzeyi/ParaAttention","owner":"chengzeyi","description":"https://wavespeed.ai/ Context parallel attention that accelerates DiT model inference with dynamic caching","archived":false,"fork":false,"pushed_at":"2025-04-02T08:27:09.000Z","size":14047,"stargazers_count":231,"open_issues_count":17,"forks_count":23,"subscribers_count":7,"default_branch":"main","last_synced_at":"2025-04-02T09:31:04.200Z","etag":null,"topics":["attention","diffusers","flux","hunyuan-video","inference","inference-engine","parallel-computing","transformers"],"latest_commit_sha":null,"homepage":"","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"other","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/chengzeyi.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE.md","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null}},"created_at":"2024-10-28T11:01:12.000Z","updated_at":"2025-04-02T08:26:35.000Z","dependencies_parsed_at":null,"dependency_job_id":"f0463e77-998b-4c62-bcf2-976606860432","html_url":"https://github.com/chengzeyi/ParaAttention","commit_stats":null,"previous_names":["chengzeyi/paraattention"],"tags_count":36,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/chengzeyi%2FParaAttention","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/chengzeyi%2FParaAttention/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/chengzeyi%2FParaAttention/releases","manifests_url":"https:
//repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/chengzeyi%2FParaAttention/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/chengzeyi","download_url":"https://codeload.github.com/chengzeyi/ParaAttention/tar.gz/refs/heads/main","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":248027407,"owners_count":21035594,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["attention","diffusers","flux","hunyuan-video","inference","inference-engine","parallel-computing","transformers"],"created_at":"2024-11-20T03:57:34.019Z","updated_at":"2025-04-09T11:06:59.933Z","avatar_url":"https://github.com/chengzeyi.png","language":"Python","funding_links":[],"categories":[],"sub_categories":[],"readme":"# ParaAttention\n\n[Blazing Fast FLUX-dev with LoRAs](https://wavespeed.ai/models/wavespeed-ai/flux-dev-lora)\n\n[Blazing Fast Wan 2.1 T2V with LoRAs](https://wavespeed.ai/models/wavespeed-ai/wan-2.1/t2v-480p)\n\n[Blazing Fast Wan 2.1 I2V with LoRAs](https://wavespeed.ai/models/wavespeed-ai/wan-2.1/i2v-480p)\n\nContext parallel attention that accelerates DiT model inference with dynamic caching,\nsupporting both [**Ulysses Style**](https://arxiv.org/abs/2309.14509) and [**Ring Style**](https://arxiv.org/abs/2310.01889) 
parallelism.\n\n[![](https://mermaid.ink/img/pako:eNqNUu9r2zAQ_VcOQUjDbMeWEycxbNAmKxS6kTH2g1X9oFgXW2BLQVbaZMH_-85J6WDsQ_VB0j096U733okVViHLWRiGwhTWbHWZCwM0DsdlJZ1_ifrxrJWvckiSOP4LVqjLyufApwSeXxkMTtpogk5DX2GDwxyGW-uw9cMOusFAmMOx6J8ON-glVNbp39Z4WQvjta8RBPsh6xq8bhDoIpRo0EmvTQnWIOhGlkjF-Apu77_9jJJQ4VMAScwnh34KgM-h9bhrA-LD5-93q7truOdxcLm0lk5ee4-UzRrBqJxQHnQLD4LdyBZrbVCwgKq4vVnKosIrp_z7OIrnoxd4PYfVl_9REj6Cd_C2c9os11d89DbehHiPwhwvlQrWImmlWsEghjD8ACk1X5iNdPDAsyjNqB2zKE5oSaMJfXwWTQmbRAseQBot4kcWsAZdI7Ui8U-9nIKd1RIsp-2GGtG3piOe3Hv79WgKlnu3x4A5uy8rlm9l3VK03ynpcaVl6WTziu6k-WVt8w_ro9LeulewtlIhhSfmj7vehKVuPSW82LDH964muPJ-1-bjcX8clSThfhMVthm3WvU2qp4W2Tjj2VzyFLNZKqdpqopNsphv-STZqlmccMm6LmB4zv_p4viz8bs_BMbpYw?type=png)](https://mermaid.live/edit#pako:eNqNUu9r2zAQ_VcOQUjDbMeWEycxbNAmKxS6kTH2g1X9oFgXW2BLQVbaZMH_-85J6WDsQ_VB0j096U733okVViHLWRiGwhTWbHWZCwM0DsdlJZ1_ifrxrJWvckiSOP4LVqjLyufApwSeXxkMTtpogk5DX2GDwxyGW-uw9cMOusFAmMOx6J8ON-glVNbp39Z4WQvjta8RBPsh6xq8bhDoIpRo0EmvTQnWIOhGlkjF-Apu77_9jJJQ4VMAScwnh34KgM-h9bhrA-LD5-93q7truOdxcLm0lk5ee4-UzRrBqJxQHnQLD4LdyBZrbVCwgKq4vVnKosIrp_z7OIrnoxd4PYfVl_9REj6Cd_C2c9os11d89DbehHiPwhwvlQrWImmlWsEghjD8ACk1X5iNdPDAsyjNqB2zKE5oSaMJfXwWTQmbRAseQBot4kcWsAZdI7Ui8U-9nIKd1RIsp-2GGtG3piOe3Hv79WgKlnu3x4A5uy8rlm9l3VK03ynpcaVl6WTziu6k-WVt8w_ro9LeulewtlIhhSfmj7vehKVuPSW82LDH964muPJ-1-bjcX8clSThfhMVthm3WvU2qp4W2Tjj2VzyFLNZKqdpqopNsphv-STZqlmccMm6LmB4zv_p4viz8bs_BMbpYw)\n\n[![](https://mermaid.ink/img/pako:eNptktuK2zAQhl9lEIS01HZsOXESXxT20NKFtgQKW-hqLxR7YgtsKcjjbbwh795x0m4PVCDQfNJoTv9RFK5EkYswDJUtnN2ZKlcWeB2Gm1p7-mmN67spqc4hSeL4N6zRVDXlIBcMz79MJkdjDaPjlGpscZrDdOc8djQ9wWkyUfYwFOPX4RZJQ-28eXaWdKMsGWoQlPiqmwbItAjsCBVa9JqMrcBZhCdTouNkqIYPvbYD7_sRBZDIVXxYyviQyHUAaQwd4b4L2As-39_d3l3BRxkHF9eN9vqKCDmms0pwUqE-mA4elLjWHTbGohIB5_L--kYX9d8GvIGbzSv5-j9w_gtu_oArho_KDpcQSnTIrS47JSCGMHwL83hsqbJb7eEhzWQWpWkAUi6TKM64riSV0ZozXyarKFkEkM3XkUwfRSBa9K02JU_wOM5EiXPLlcj5uOU6xspO_E735L4MthA5-R4D4V1f1SLf6aZjq9-XmvDW6Mrr9oXutf3mXPvPq3elIedfYON0iWweBQ37UUmV6YgDXrQ08t43jGuifZfPZuN1VPEE-m1UuHbW
mXLUQv20zmZc-ErLFLNlqhdpWhbbZL3ayXmyK5dxIrU4nQKB5_ifLrI9q_f0AzkD3HY?type=png)](https://mermaid.live/edit#pako:eNptktuK2zAQhl9lEIS01HZsOXESXxT20NKFtgQKW-hqLxR7YgtsKcjjbbwh795x0m4PVCDQfNJoTv9RFK5EkYswDJUtnN2ZKlcWeB2Gm1p7-mmN67spqc4hSeL4N6zRVDXlIBcMz79MJkdjDaPjlGpscZrDdOc8djQ9wWkyUfYwFOPX4RZJQ-28eXaWdKMsGWoQlPiqmwbItAjsCBVa9JqMrcBZhCdTouNkqIYPvbYD7_sRBZDIVXxYyviQyHUAaQwd4b4L2As-39_d3l3BRxkHF9eN9vqKCDmms0pwUqE-mA4elLjWHTbGohIB5_L--kYX9d8GvIGbzSv5-j9w_gtu_oArho_KDpcQSnTIrS47JSCGMHwL83hsqbJb7eEhzWQWpWkAUi6TKM64riSV0ZozXyarKFkEkM3XkUwfRSBa9K02JU_wOM5EiXPLlcj5uOU6xspO_E735L4MthA5-R4D4V1f1SLf6aZjq9-XmvDW6Mrr9oXutf3mXPvPq3elIedfYON0iWweBQ37UUmV6YgDXrQ08t43jGuifZfPZuN1VPEE-m1UuHbWmXLUQv20zmZc-ErLFLNlqhdpWhbbZL3ayXmyK5dxIrU4nQKB5_ifLrI9q_f0AzkD3HY)\n\n🔥[Fastest FLUX.1-dev Inference with Context Parallelism and First Block Cache on NVIDIA L20 GPUs](./doc/fastest_flux.md)🔥\n\n🔥[Fastest HunyuanVideo Inference with Context Parallelism and First Block Cache on NVIDIA L20 GPUs](./doc/fastest_hunyuan_video.md)🔥\n\nThis aims to provide:\n\n- [x] An easy to use interface to speed up model inference with context parallel, dynamic caching and `torch.compile`. Make **`FLUX`**, **`HunyuanVideo`** and **`Mochi`** inference much faster losslessly.\n- [x] A unified interface to run context parallel attention (***cfg-ulysses-ring***), as well as keeping the maximum performance while working with `torch.compile`\n- [ ] The fastest accurate attention implemented in Triton, running 50% faster than the originial FA2 implementation on RTX 4090.\n\nWhat's different from other implementations:\n\n- No unnecessary graph breaks during `torch.compile`. All the heavy computations are captured in a single graph and get the maximum opportunity to be optimized. This makes it possible for the backend compiler to optimize the graph more effectively, for example, by overlapping the computation and communication.\n- Easy to use. You don't need to change the code of the model to enable context parallelism. 
Instead, you only need to call a function to parallelize the model.\n- Easy to use, too. If you want to use context parallelism with your custom model, you only need to wrap the call with our special `TorchFunctionMode` context manager.\n- Easy to adjust. You can adjust the parallelism style and the mesh shape with a few lines of code.\n\n# Key Features\n\n### Context Parallelism\n\n**Context Parallelism** (CP) is a method for parallelizing the processing of neural network activations across multiple GPUs by partitioning the input tensors along the sequence dimension.\nUnlike Sequence Parallelism (SP) that partitions the activations of specific layers, CP divides the activations of all layers.\nIn `ParaAttention`, we are able to parallelize the attention layer with a mixture of Ulysses Style and Ring Style parallelism, called Unified Attention.\nThis allows us to achieve the best performance with different models and different hardware configurations.\nWe also provide a unified interface to parallelize the model inference.\n\nYou only need to call a single function to enable context parallelism on your `diffusers` pipeline:\n\n```python\nfrom para_attn.context_parallel.diffusers_adapters import parallelize_pipe\n\nparallelize_pipe(pipe)\n```\n\n### First Block Cache (Our Dynamic Caching)\n\nInspired by [TeaCache](https://github.com/ali-vilab/TeaCache) and other denoising caching algorithms, we introduce **First Block Cache** (FBCache) to use the residual output of the first transformer block as the cache indicator.\nIf the difference between the current and the previous residual output of the first transformer block is small enough, we can reuse the previous final residual output and skip the computation of all the following transformer blocks.\nThis can significantly reduce the computation cost of the model, achieving a speedup of up to 2x while maintaining high accuracy.\n\n#### Optimizations for FLUX Image Generation Model on a Single NVIDIA L20 GPU\n\n| 
Optimizations | Original | FBCache rdt=0.06 | FBCache rdt=0.08 | FBCache rdt=0.10 | FBCache rdt=0.12 |\n| - | - | - | - | - | - |\n| Preview | ![Original](./assets/flux_original.png) | ![FBCache rdt=0.06](./assets/flux_fbc_0.06.png) | ![FBCache rdt=0.08](./assets/flux_fbc_0.08.png) | ![FBCache rdt=0.10](./assets/flux_fbc_0.10.png) | ![FBCache rdt=0.12](./assets/flux_fbc_0.12.png) |\n| Wall Time (s) | 26.36 | 21.83 | 17.01 | 16.00 | 13.78 |\n\n#### Optimizations for Video Models\n\n| Model | Optimizations | Preview |\n| - | - | - |\n| HunyuanVideo | Original | [Original](https://github.com/user-attachments/assets/883d771a-e74e-4081-aa2a-416985d6c713) |\n| HunyuanVideo | FBCache | [FBCache](https://github.com/user-attachments/assets/f77c2f58-2b59-4dd1-a06a-a36974cb1e40) |\n\nYou only need to call a single function to enable First Block Cache on your `diffusers` pipeline:\n\n```python\nfrom para_attn.first_block_cache.diffusers_adapters import apply_cache_on_pipe\n\napply_cache_on_pipe(\n    pipe,\n    # residual_diff_threshold=0.0,\n)\n```\n\nAdjust the `residual_diff_threshold` to balance the speedup and the accuracy.\nHigher `residual_diff_threshold` will lead to more cache hits and higher speedup, but might also lead to a higher accuracy drop.\n\n# Officially Supported Models\n\n## Context Parallelism with First Block Cache\n\nYou could run the following examples with `torchrun` to enable context parallelism with dynamic caching.\nYou can modify the code to enable `torch.compile` to further accelerate the model inference.\nIf you want quantization, please refer to [diffusers-torchao](https://github.com/sayakpaul/diffusers-torchao) for more information.\nFor example, to run FLUX with 2 GPUs:\n\n**Note**: To measure the performance correctly with `torch.compile`, you need to warm up the model by running it for a few iterations before measuring the performance.\n\n```bash\n# Use --nproc_per_node to specify the number of GPUs\ntorchrun --nproc_per_node=2 
parallel_examples/run_flux.py\n```\n\n- [FLUX🚀](parallel_examples/run_flux.py)\n- [HunyuanVideo🚀](parallel_examples/run_hunyuan_video.py)\n- [Mochi](parallel_examples/run_mochi.py)\n- [CogVideoX](parallel_examples/run_cogvideox.py)\n\n## Single GPU Inference with First Block Cache\n\nYou can also run the following examples with a single GPU and enable the First Block Cache to speed up the model inference.\n\n```bash\npython3 first_block_cache_examples/run_hunyuan_video.py\n```\n\n- [FLUX🚀](first_block_cache_examples/run_flux.py)\n- [HunyuanVideo🚀](first_block_cache_examples/run_hunyuan_video.py)\n- [Mochi](first_block_cache_examples/run_mochi.py)\n- [CogVideoX](first_block_cache_examples/run_cogvideox.py)\n\n**NOTE**: To run `HunyuanVideo`, you need to install `diffusers` from its latest master branch.\nIt is suggested to run `HunyuanVideo` with GPUs with at least 48GB memory, or you might experience OOM errors,\nand the performance might be worse due to frequent memory re-allocation.\n\n# Performance\n\n## Context Parallelism (without First Block Cache)\n\n| Model | GPU | Method | Wall Time (s) | Speedup |\n| --- | --- | --- | --- | --- |\n| FLUX.1-dev | A100-SXM4-80GB | Baseline | 13.843 | 1.00x |\n| FLUX.1-dev | A100-SXM4-80GB | `torch.compile` | 9.997 | 1.38x |\n| FLUX.1-dev | A100-SXM4-80GB x 2 | `para-attn (ring)` | 8.307 | 1.66x |\n| FLUX.1-dev | A100-SXM4-80GB x 2 | `para-attn (ring)` + `torch.compile` | 5.775 | 2.39x |\n| FLUX.1-dev | A100-SXM4-80GB x 4 | `para-attn (ulysses + ring)` | 6.157 | 2.25x |\n| FLUX.1-dev | A100-SXM4-80GB x 4 | `para-attn (ulysses + ring)` + `torch.compile` | 3.557 | 3.89x |\n| mochi-1-preview | A100-SXM4-80GB | Baseline | 196.534 | 1.00x |\n| mochi-1-preview | A100-SXM4-80GB | `torch.compile` | 149.868 | 1.31x |\n| mochi-1-preview | A100-SXM4-80GB x 2 | `para-attn (cfg)` | 105.438 | 1.86x |\n| mochi-1-preview | A100-SXM4-80GB x 2 | `para-attn (ulysses)` | 110.146 | 1.78x |\n| mochi-1-preview | A100-SXM4-80GB x 2 | `para-attn 
(ring)` | 109.435 | 1.80x |\n| mochi-1-preview | A100-SXM4-80GB x 2 | `para-attn (cfg)` + `torch.compile` | 81.913 | 2.40x |\n| mochi-1-preview | A100-SXM4-80GB x 2 | `para-attn (ulysses)` + `torch.compile` | 83.912 | 2.34x |\n| mochi-1-preview | A100-SXM4-80GB x 2 | `para-attn (ring)` + `torch.compile` | 82.176 | 2.39x |\n| mochi-1-preview | A100-SXM4-80GB x 4 | `para-attn (cfg + ring)` | 61.206 | 3.21x |\n| mochi-1-preview | A100-SXM4-80GB x 4 | `para-attn (cfg + ring)` + `torch.compile` | 47.100 | 4.17x |\n\n**NOTE**: The speedup of iterations per second is generally higher than the speedup of wall time, because the wall time includes the overhead of calling the text encoder and vae decoder.\n\n# Installation\n\n## Install from PyPI\n\n```bash\npip3 install 'torch==2.6.0'\npip3 install para-attn\n```\n\n## Local Installation\n\n```bash\ngit clone https://github.com/chengzeyi/ParaAttention.git\ncd ParaAttention\ngit submodule update --init --recursive\n\npip3 install 'torch==2.6.0'\npip3 install 'setuptools\u003e=64' 'setuptools_scm\u003e=8'\n\n# Pass --no-use-pep517 to speed up rebuild by using the legacy build system\n# which doesn't use a one-time tmp directory for the build\npip3 install -e '.[dev]' --no-build-isolation\n# Or:\n# python3 setup.py develop\n\n# Code formatting and linting\npip3 install pre-commit\npre-commit install\npre-commit run --all-files\n```\n\n# Usage\n\n## All Examples\n\nPlease refer to examples in the `parallel_examples` and `first_block_cache_examples` directories.\n\n### Parallelize Models\n\n| Model | Command |\n| - | - |\n| `FLUX` | `torchrun --nproc_per_node=2 parallel_examples/run_flux.py` |\n| `HunyuanVideo` | `torchrun --nproc_per_node=2 parallel_examples/run_hunyuan_video.py` |\n| `Mochi` | `torchrun --nproc_per_node=2 parallel_examples/run_mochi.py` |\n| `CogVideoX` | `torchrun --nproc_per_node=2 parallel_examples/run_cogvideox.py` |\n\n### Apply First Block Cache\n\n| Model | Command |\n| - | - |\n| `FLUX` | `python3 
first_block_cache_examples/run_flux.py` |\n| `HunyuanVideo` | `python3 first_block_cache_examples/run_hunyuan_video.py` |\n| `Mochi` | `python3 first_block_cache_examples/run_mochi.py` |\n| `CogVideoX` | `python3 first_block_cache_examples/run_cogvideox.py` |\n\n## Parallelize VAE\n\nVAE can be parallelized with `para_attn.parallel_vae.diffusers_adapters.parallelize_vae`.\nCurrently, only `AutoencoderKL` and `AutoencoderKLHunyuanVideo` are supported.\n\n``` python\nimport torch\nimport torch.distributed as dist\nfrom diffusers import AutoencoderKL\n\ndist.init_process_group()\n\ntorch.cuda.set_device(dist.get_rank())\n\nvae = AutoencoderKL.from_pretrained(\n    \"black-forest-labs/FLUX.1-dev\",\n    torch_dtype=torch.bfloat16,\n).to(\"cuda\")\n\nfrom para_attn.parallel_vae.diffusers_adapters import parallelize_vae\n\nparallelize_vae(vae)\n```\n\n## Run Unified Attention (Hybrid Ulysses Style and Ring Style) with `torch.compile`\n\n```python\nimport torch\nimport torch.distributed as dist\nimport torch.nn.functional as F\nfrom para_attn import para_attn_interface\n\ndist.init_process_group()\nworld_size = dist.get_world_size()\nrank = dist.get_rank()\n\nassert world_size \u003c= torch.cuda.device_count()\nif world_size % 2 == 0:\n    mesh_shape = (2, world_size // 2)\nelse:\n    mesh_shape = (1, world_size)\n\nB, H, S_Q, S_KV, D = 2, 24, 4096, 4096, 64\ndtype = torch.float16\ndevice = \"cuda\"\n\ndef func(*args, **kwargs):\n    return F.scaled_dot_product_attention(*args, **kwargs)\n\n# torch._inductor.config.reorder_for_compute_comm_overlap = True\n# func = torch.compile(func)\n\nwith torch.no_grad(), torch.cuda.device(rank):\n    torch.manual_seed(0)\n\n    query = torch.randn(B, H, S_Q, D, dtype=dtype, device=device)\n    key = torch.randn(B, H, S_KV, D, dtype=dtype, device=device)\n    value = torch.randn(B, H, S_KV, D, dtype=dtype, device=device)\n    attn_mask = None\n    dropout_p = 0.0\n    is_causal = False\n\n    query_slice = query.chunk(world_size, 
dim=-2)[rank]\n    key_slice = key.chunk(world_size, dim=-2)[rank]\n    value_slice = value.chunk(world_size, dim=-2)[rank]\n\n    for _ in range(2):\n        mesh = dist.init_device_mesh(device, mesh_shape, mesh_dim_names=(\"ring\", \"ulysses\"))\n        with para_attn_interface.UnifiedAttnMode(mesh):\n            out_slice = func(\n                query_slice,\n                key_slice,\n                value_slice,\n                attn_mask=attn_mask,\n                dropout_p=dropout_p,\n                is_causal=is_causal,\n            )\n\n    out_slice_ref = F.scaled_dot_product_attention(\n        query,\n        key,\n        value,\n        attn_mask=attn_mask,\n        dropout_p=dropout_p,\n        is_causal=is_causal,\n    ).chunk(world_size, dim=-2)[rank]\n\n    torch.testing.assert_close(out_slice, out_slice_ref, rtol=1e-5, atol=1e-3 * world_size)\n\ndist.destroy_process_group()\n```\n\nSave the above code to `test.py` and run it with `torchrun`:\n\n```bash\ntorchrun --nproc_per_node=2 test.py\n```\n\n# Run Tests\n\n```bash\npytest tests --html=report.html --self-contained-html\n```\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fchengzeyi%2Fparaattention","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fchengzeyi%2Fparaattention","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fchengzeyi%2Fparaattention/lists"}