https://github.com/kingoflolz/swarm-jax
Swarm training framework using Haiku + JAX + Ray for layer parallel transformer language models on unreliable, heterogeneous nodes
- Host: GitHub
- URL: https://github.com/kingoflolz/swarm-jax
- Owner: kingoflolz
- Created: 2020-12-22T01:15:29.000Z (about 4 years ago)
- Default Branch: master
- Last Pushed: 2023-05-12T08:56:46.000Z (over 1 year ago)
- Last Synced: 2024-05-22T13:31:10.274Z (8 months ago)
- Language: Python
- Homepage:
- Size: 53.7 KB
- Stars: 229
- Watchers: 5
- Forks: 21
- Open Issues: 1
Metadata Files:
- Readme: readme.md
Awesome Lists containing this project
- awesome-ray - Swarm-jax - Swarm training framework using Haiku + JAX + Ray for layer parallel transformer language models on unreliable, heterogeneous nodes (Models and Projects / Ray + JAX / TPU)
README
# Pipelined Swarm Training
Swarm training "framework" using Haiku + JAX + Ray.
Designed (eventually) for training large language models in a model-parallel fashion on unreliable, heterogeneous nodes.
See `swarm_run.py` for an example that runs a character-level transformer on enwik8.
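In pipelined, layer-parallel training each node holds one layer, and microbatches are staggered so that several nodes work at the same time, each on a different microbatch. The sketch below is a hypothetical illustration of that idea, not swarm-jax's actual scheduler: it assumes exactly one layer per node and a simple fill/drain forward schedule, and `forward_schedule` is a name invented here.

```python
def forward_schedule(num_layers: int, num_microbatches: int) -> list[list[tuple[int, int]]]:
    """Return, for each clock tick, the (layer, microbatch) pairs that run in parallel.

    Hypothetical fill/drain schedule: microbatch m reaches layer l at tick m + l,
    so once the pipeline is full every layer is busy with a different microbatch.
    """
    num_ticks = num_layers + num_microbatches - 1
    ticks = []
    for t in range(num_ticks):
        active = []
        for layer in range(num_layers):
            mb = t - layer  # the microbatch that arrives at this layer at tick t
            if 0 <= mb < num_microbatches:
                active.append((layer, mb))
        ticks.append(active)
    return ticks


if __name__ == "__main__":
    for t, work in enumerate(forward_schedule(num_layers=3, num_microbatches=4)):
        print(t, work)
```

With 3 layers and 4 microbatches the schedule takes 6 ticks; during the first two (fill) and last two (drain) ticks some layers sit idle, which is the pipeline "bubble" that more microbatches amortize.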
# TODOs
- [x] Forward passes
- [x] Backward passes with activation reconstruction
- [x] Run optimizer
- [x] Logging
- [x] Checkpointing
- [x] Actually do pipelining
- [x] fp16 with static loss scaling
- [x] Integer quantization for activations and gradients between layers
- [ ] Get rid of pipeline stalls from running optimizer
- [ ] Data parallelism with multiple nodes per layer and gradient/weight aggregation
- [ ] Heterogeneous nodes with potentially multiple layers per node
- [ ] Handle unbalanced and unreliable nodes (layerdrop)
- [ ] Dynamic node addition
- [ ] 1T or bust?
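Two of the completed items, fp16 with static loss scaling and integer quantization of activations and gradients passed between layers, can be illustrated with a small pure-Python sketch. It assumes a fixed loss scale of 1024 and symmetric per-tensor int8 quantization with one float scale per message; swarm-jax's actual scale, bit width, and scheme may differ, and `LOSS_SCALE`, `to_fp16`, `scaled_fp16_grad`, `quantize_int8`, and `dequantize_int8` are hypothetical names.

```python
import struct

# Hypothetical value: the loss scale swarm-jax actually uses is not stated here.
LOSS_SCALE = 1024.0


def to_fp16(x: float) -> float:
    """Round a Python float to IEEE half precision and back (stdlib 'e' struct format)."""
    return struct.unpack("<e", struct.pack("<e", x))[0]


def scaled_fp16_grad(grad: float, loss_scale: float = LOSS_SCALE) -> float:
    """Simulate a gradient computed in fp16 from a scaled loss.

    Scaling the loss by a constant scales every gradient by the same constant, so the
    fp16 value holds grad * loss_scale; the division back happens in full precision.
    """
    return to_fp16(grad * loss_scale) / loss_scale


def quantize_int8(values: list[float]) -> tuple[list[int], float]:
    """Symmetric per-tensor int8 quantization: values[i] ~= q[i] * scale, q[i] in [-127, 127]."""
    max_abs = max((abs(v) for v in values), default=0.0)
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    q = [max(-127, min(127, round(v / scale))) for v in values]
    return q, scale


def dequantize_int8(q: list[int], scale: float) -> list[float]:
    """Receiving side: recover approximate float values from int8 codes and the scale."""
    return [qi * scale for qi in q]


if __name__ == "__main__":
    tiny_grad = 1e-8
    print(to_fp16(tiny_grad))           # 0.0: underflows in fp16 without scaling
    print(scaled_fp16_grad(tiny_grad))  # close to 1e-8: survives with scaling

    activations = [0.5, -1.0, 0.25, 0.0]
    q, scale = quantize_int8(activations)  # one byte per value plus one float scale
    print(q, dequantize_int8(q, scale))
```

Sending one-byte codes plus a single scale cuts inter-layer traffic roughly 4× compared with fp32 (2× compared with fp16), at the cost of a rounding error of at most half the scale per value.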