https://github.com/kyegomez/tree-attention-torch
An implementation of Tree-Attention in PyTorch, because the original implementation is in JAX for some reason
- Host: GitHub
- URL: https://github.com/kyegomez/tree-attention-torch
- Owner: kyegomez
- License: mit
- Created: 2024-09-09T19:25:26.000Z
- Default Branch: main
- Last Pushed: 2025-09-08T04:15:02.000Z
- Last Synced: 2025-09-09T03:58:43.184Z
- Topics: ai, attention, dao, deep-learning, distributed, gpu, llm, machine-learning, ml, nvidia, parallel, research, stanford, tri
- Language: Python
- Homepage: https://discord.com/servers/agora-999382051935506503
- Size: 2.18 MB
- Stars: 4
- Watchers: 1
- Forks: 0
- Open Issues: 7
Metadata Files:
- Readme: README.md
- Funding: .github/FUNDING.yml
- License: LICENSE
# Tree Attention Torch
An implementation of Tree-Attention in PyTorch, because the original implementation is in JAX for some reason
[Discord](https://discord.gg/agora-999382051935506503) [YouTube](https://www.youtube.com/@kyegomez3242) [LinkedIn](https://www.linkedin.com/in/kye-g-38759a207/) [X](https://x.com/kyegomezb)
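The core trick behind Tree-Attention, as we understand it, is that softmax attention for a query over a long sequence can be split into chunks (for example, one per GPU). Each chunk reports three things: its local max score, the sum of its exponentiated (max-shifted) scores, and an unnormalized weighted sum of its values. These triples merge with an associative log-sum-exp rescaling, so the merge can be done as a tree-shaped reduction in O(log P) rounds for P chunks. The sketch below is a minimal single-process illustration of that reduction, assuming a single query vector and using a Python list of chunks to stand in for devices. It is not this repo's `model.py`, and the function names (`chunk_partials`, `combine`, `tree_reduce`, `tree_attention`) are hypothetical.

```python
import torch

# Hypothetical helpers: a list of chunks stands in for a set of devices.


def chunk_partials(q, k, v):
    """Partial softmax-attention results for one query over one chunk.

    q: (d,), k: (n, d), v: (n, d_v).
    Returns (m, s, o): the chunk's max score, sum of exp(score - m),
    and the unnormalized output sum_i exp(score_i - m) * v_i.
    """
    scores = k @ q / q.shape[-1] ** 0.5   # (n,) scaled dot-product scores
    m = scores.max()
    w = torch.exp(scores - m)             # shift by max for numerical stability
    return m, w.sum(), w @ v              # scalar, scalar, (d_v,)


def combine(a, b):
    """Associative merge of two partial results via log-sum-exp rescaling."""
    m_a, s_a, o_a = a
    m_b, s_b, o_b = b
    m = torch.maximum(m_a, m_b)
    ca, cb = torch.exp(m_a - m), torch.exp(m_b - m)
    return m, s_a * ca + s_b * cb, o_a * ca + o_b * cb


def tree_reduce(parts):
    """Pairwise (tree-shaped) reduction: about log2(P) rounds for P parts."""
    while len(parts) > 1:
        nxt = [combine(parts[i], parts[i + 1])
               for i in range(0, len(parts) - 1, 2)]
        if len(parts) % 2:                # odd part out carries to the next round
            nxt.append(parts[-1])
        parts = nxt
    return parts[0]


def tree_attention(q, k, v, num_chunks=4):
    """Attention for one query over a sequence split into num_chunks shards."""
    parts = [chunk_partials(q, kc, vc)
             for kc, vc in zip(k.chunk(num_chunks), v.chunk(num_chunks))]
    _, s, o = tree_reduce(parts)
    return o / s                          # normalize once, after the reduction


# Usage: compare against ordinary softmax attention over the full sequence.
torch.manual_seed(0)
d, d_v, n = 64, 32, 1000
q = torch.randn(d)
k = torch.randn(n, d)
v = torch.randn(n, d_v)
ref = torch.softmax(k @ q / d ** 0.5, dim=0) @ v
out = tree_attention(q, k, v, num_chunks=7)  # 7 parts -> 4 -> 2 -> 1
print(torch.allclose(out, ref, atol=1e-5))
```

Normalizing only once at the end is what makes the merge associative: every intermediate result is still a (max, sum, numerator) triple, so the reduction order does not matter. In a multi-GPU setting, the combine step would presumably map onto collectives such as `torch.distributed.all_reduce` with a MAX reduction for the scores followed by SUM reductions of the rescaled terms. That wiring is left out of this sketch.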
## Usage
```bash
python3 model.py
```
## License
MIT
## Todo
- [ ] Integrate FlashAttention from the official repo; I couldn't get it working because its docs are hard to find and understand