{"id":38781232,"url":"https://github.com/atomicarchitects/fusionfail","last_synced_at":"2026-01-17T12:25:00.976Z","repository":{"id":230304810,"uuid":"778502299","full_name":"atomicarchitects/FusionFail","owner":"atomicarchitects","description":"Profile showing 3 layers of NequIP using e3nn-jax","archived":false,"fork":false,"pushed_at":"2024-04-03T18:55:14.000Z","size":1979,"stargazers_count":0,"open_issues_count":0,"forks_count":0,"subscribers_count":4,"default_branch":"main","last_synced_at":"2025-09-09T12:44:46.755Z","etag":null,"topics":[],"latest_commit_sha":null,"homepage":"","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":null,"status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/atomicarchitects.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":null,"code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null,"zenodo":null,"notice":null,"maintainers":null,"copyright":null,"agents":null,"dco":null,"cla":null}},"created_at":"2024-03-27T20:49:06.000Z","updated_at":"2024-03-30T01:31:58.000Z","dependencies_parsed_at":"2025-09-09T11:34:42.234Z","dependency_job_id":"cde3db53-f08d-463b-abff-58cb8b067714","html_url":"https://github.com/atomicarchitects/FusionFail","commit_stats":null,"previous_names":["atomicarchitects/fusionfail"],"tags_count":0,"template":false,"template_full_name":null,"purl":"pkg:github/atomicarchitects/FusionFail","repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/atomicarchitects%2FFusionFail","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/atomicarchitects%2FFusionFail/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/atomicarchitects%2FFusionFail/releases
","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/atomicarchitects%2FFusionFail/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/atomicarchitects","download_url":"https://codeload.github.com/atomicarchitects/FusionFail/tar.gz/refs/heads/main","sbom_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/atomicarchitects%2FFusionFail/sbom","scorecard":null,"host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":286080680,"owners_count":28508464,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2026-01-17T11:50:55.898Z","status":"ssl_error","status_checked_at":"2026-01-17T11:50:55.569Z","response_time":85,"last_error":"SSL_connect returned=1 errno=0 peeraddr=140.82.121.6:443 state=error: unexpected eof while reading","robots_txt_status":"success","robots_txt_updated_at":"2025-07-24T06:49:26.215Z","robots_txt_url":"https://github.com/robots.txt","online":false,"can_crawl_api":true,"host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":[],"created_at":"2026-01-17T12:25:00.867Z","updated_at":"2026-01-17T12:25:00.961Z","avatar_url":"https://github.com/atomicarchitects.png","language":"Python","funding_links":[],"categories":[],"sub_categories":[],"readme":"# Fusion Fail\n\n![nequip_profile](images/profile_nequip_3_layer.png)\n\n## What are we looking at ?\n\nThe function `train_step` corresponds to a forward and backward pass through a 3 layered [NequIP](https://www.nature.com/articles/s41467-022-29939-5) model implemented using [e3nn-jax](https://github.com/e3nn/e3nn-jax) acting on a simple Tetris dataset. 
Thanks @ameya98 @mariogeiger for the code !\n\n## What's happening ?\n\nHere's a brief summary of the under-the-hood story:\n\n- [XLA](https://github.com/openxla/xla) is unable to pattern-match or generate a small set of fused kernels for the computation (see [arxiv:2301.13062](https://arxiv.org/abs/2301.13062) to understand how XLA works). Instead, it is left with roughly 300 kernels (half of which are cuBLAS/CUTLASS calls) that it needs to execute at runtime (the small chunks below `Thunk:#hlo_op` in the `TSL` row).\n\n- This makes the compiler fall back to [CUDAGraphs](https://developer.nvidia.com/blog/cuda-graphs/), which batches the execution of these kernels. However, the execution graph needs to be updated with new inputs at runtime (~30% runtime overhead before `Graph 7` is launched on the GPU). This overhead (notice the `CUDA API` row) increases with the size of the computation graph.\n\n## What's the alternative ?\n\nIdeally, the compiler/human should be giving us one forward and one backward fused kernel for our computation (see [FlashAttention](https://arxiv.org/abs/2205.14135)).\n\n### Packages\n\n```bash\npip install -r requirements.txt\n```\n\nTo reproduce the profile shown above, install NVIDIA Nsight Systems and run the following command (borrowed from [JAX-Toolbox](https://github.com/NVIDIA/JAX-Toolbox/blob/main/docs/profiling.md)):\n\n```bash\nnsys profile --capture-range=cudaProfilerApi --cuda-graph-trace=node --capture-range-end=stop -o nequip_profile_disable_cudagraph -f true python train.py\n```\n\n## TODO\n\n- [ ] Add an MLP equivalent to show what non-CUDAGraph fusion should look like\n- More profiling:\n    - [ ] Add `TensorProduct`, `TensorProductLinear` and `TensorProductLinearGate`\n    - [ ] Allegro-JAX and 
MACE-JAX\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fatomicarchitects%2Ffusionfail","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fatomicarchitects%2Ffusionfail","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fatomicarchitects%2Ffusionfail/lists"}