https://github.com/0x7o/ae
Scalable code for training and fine-tuning language models on TPUs
- Host: GitHub
- URL: https://github.com/0x7o/ae
- Owner: 0x7o
- License: MIT
- Created: 2024-02-25T14:12:18.000Z (over 1 year ago)
- Default Branch: master
- Last Pushed: 2025-02-04T10:59:29.000Z (4 months ago)
- Last Synced: 2025-02-04T11:26:42.343Z (4 months ago)
- Topics: large-language-models, scaling, tpu
- Language: Python
- Homepage:
- Size: 109 KB
- Stars: 0
- Watchers: 1
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
README

# (WIP) æ
Codebase for training GPT-like models on TPUs, with support for parallelization and scaling via JAX.

## To Do
- [x] Data parallelization across devices with `jax.sharding` (see the sketch after this list)
- [x] Support for bfloat16 during training
- [ ] Model parallelization with `jax.pjit` and `Mesh`
- [ ] Flash Attention support
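
As a rough illustration of the two checked items, a data-parallel training step with `jax.sharding` and bfloat16 parameters could look like the minimal sketch below. This is not code from this repository: the model, loss, shapes, and learning rate are placeholders, and only the `Mesh`/`NamedSharding` usage and the bfloat16 dtypes reflect the techniques named above.

```python
# Minimal sketch (not from this repo): data parallelism with jax.sharding
# plus bfloat16 parameters. The "model" is a single placeholder matmul.
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1D mesh over all local devices (e.g. the 8 cores of a TPU v3-8).
mesh = Mesh(jax.devices(), axis_names=("data",))
batch_sharding = NamedSharding(mesh, P("data"))  # split batch dim across devices
replicated = NamedSharding(mesh, P())            # replicate parameters

def loss_fn(params, batch):
    # Placeholder loss; a GPT-like model would run a transformer stack here.
    preds = batch["inputs"] @ params["w"]
    return jnp.mean((preds - batch["targets"]) ** 2)

@jax.jit
def train_step(params, batch, lr=1e-3):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
    return params, loss

# bfloat16 parameters are replicated; the global batch is sharded along its
# leading axis, which must be divisible by the device count.
params = {"w": jax.device_put(jnp.ones((128, 128), jnp.bfloat16), replicated)}
batch = {
    "inputs": jax.device_put(jnp.ones((32, 128), jnp.bfloat16), batch_sharding),
    "targets": jax.device_put(jnp.zeros((32, 128), jnp.bfloat16), batch_sharding),
}
params, loss = train_step(params, batch)
```

The unchecked model-parallelism item would presumably extend the same idea to a 2D mesh (e.g. `("data", "model")` axes) and partition weight matrices along the `model` axis; since recent JAX versions fold `pjit` into `jax.jit`, the sharded-`jit` pattern above carries over.

## Special Thanks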
- [Phil Wang](https://github.com/lucidrains) for the [PaLM-jax](https://github.com/lucidrains/PaLM-jax) implementation
- [Hugging Face](https://huggingface.co/) for the [transformers](https://github.com/huggingface/transformers) library

## License
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.