https://github.com/kingoflolz/swarm-jax

Swarm training framework using Haiku + JAX + Ray for layer parallel transformer language models on unreliable, heterogeneous nodes

# Pipelined Swarm Training

Swarm training "framework" using Haiku + JAX + Ray.

Designed for training large language models in a model-parallel fashion, and eventually for doing so on unreliable, heterogeneous nodes.

See `swarm_run.py` for an example of running a character-level transformer on enwik8.

# TODOs

- [x] Forward passes
- [x] Backward passes with activation reconstruction
- [x] Run optimizer
- [x] Logging
- [x] Checkpointing
- [x] Actually do pipelining
- [x] fp16 with static loss scaling
- [x] Integer quantization for activations and gradients between layers
- [ ] Get rid of pipeline stalls caused by running the optimizer
- [ ] Data parallelism with multiple nodes per layer and gradient/weight aggregation
- [ ] Heterogeneous nodes with potentially multiple layers per node
- [ ] Handle unbalanced and unreliable nodes (layerdrop)
- [ ] Dynamic node addition
- [ ] 1T or bust?
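The "Backward passes with activation reconstruction" item refers to recomputing a layer's activations during the backward pass instead of keeping them in memory from the forward pass. The README does not show how swarm-jax implements this, so the following is a minimal pure-Python sketch under assumed simplifications: each hypothetical "layer" is a scalar `tanh(w * x)`, and only the input to each layer is saved during the forward pass. In JAX itself, `jax.checkpoint` (also exposed as `jax.remat`) provides this kind of rematerialization.

```python
import math


# Hypothetical scalar "layer": y = tanh(w * x).
def layer_forward(w, x):
    return math.tanh(w * x)


def forward_stack(weights, x):
    """Run the layers in order, saving only each layer's input."""
    saved_inputs = []
    for w in weights:
        saved_inputs.append(x)  # the output y is NOT stored
        x = layer_forward(w, x)
    return x, saved_inputs


def layer_backward(w, x, grad_y):
    """Backward pass that reconstructs the activation from the saved input."""
    y = layer_forward(w, x)      # recompute the activation instead of loading it
    local = 1.0 - y * y          # d tanh(u) / du at u = w * x
    grad_w = grad_y * local * x  # dL/dw for this layer
    grad_x = grad_y * local * w  # dL/dx, passed to the previous layer
    return grad_w, grad_x


def backward_stack(weights, saved_inputs, grad_out):
    """Walk the layers in reverse, recomputing each one's output on the fly."""
    grads = [0.0] * len(weights)
    g = grad_out
    for i in reversed(range(len(weights))):
        grads[i], g = layer_backward(weights[i], saved_inputs[i], g)
    return grads, g
```

Calling `forward_stack([0.5, -1.2, 0.9], 0.8)` and passing the saved inputs to `backward_stack` yields gradients that match central finite differences of `forward_stack`. The trade is one extra forward computation per layer in exchange for not holding intermediate activations.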
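For the "Actually do pipelining" item, the sketch below shows what a layer-pipelined forward schedule looks like when a batch is split into microbatches. It assumes a GPipe-style fill-and-drain schedule, which the README does not confirm is the one swarm-jax uses; `forward_schedule` and `utilization` are hypothetical names.

```python
def forward_schedule(num_stages, num_microbatches):
    """For each time step, list the (stage, microbatch) pairs that run.

    Assumes GPipe-style fill and drain: stage s handles microbatch m at
    step s + m, so a microbatch moves one stage down the pipeline per step.
    """
    num_steps = num_stages + num_microbatches - 1
    schedule = []
    for t in range(num_steps):
        schedule.append([(s, t - s) for s in range(num_stages)
                         if 0 <= t - s < num_microbatches])
    return schedule


def utilization(num_stages, num_microbatches):
    """Fraction of stage-steps doing useful work (1 minus the bubble)."""
    busy = num_stages * num_microbatches
    total = num_stages * (num_stages + num_microbatches - 1)
    return busy / total


for t, active in enumerate(forward_schedule(3, 4)):
    print(t, active)
```

With 3 stages and 4 microbatches the forward pass takes 6 steps: step 0 runs only `(0, 0)`, step 2 runs `(0, 2), (1, 1), (2, 0)` with every stage busy, and step 5 runs only `(2, 3)`. Each stage idles for 2 of the 6 steps, so `utilization(3, 4)` is 2/3; splitting a batch into more microbatches shrinks this idle "bubble".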
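Static loss scaling, listed under the fp16 item, multiplies the loss by a fixed constant before the backward pass so that small gradients do not underflow to zero in half precision, then divides the gradients by the same constant before the optimizer step. The sketch below simulates fp16 rounding with Python's `struct` module (format character `e`, IEEE 754 binary16) on single values; the scale of `2 ** 15` is an assumed example, not the value used in the repo, and `fp16_grad`/`unscale` are hypothetical helpers.

```python
import math
import struct

LOSS_SCALE = 2.0 ** 15  # assumed example value; a power of two keeps (un)scaling exact


def to_fp16(x):
    """Round a Python float to IEEE 754 half precision (binary16)."""
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:
        return math.copysign(math.inf, x)  # values beyond fp16 range become +/-inf


def fp16_grad(true_grad, loss_scale=1.0):
    """Simulate the fp16 gradient produced by backprop of (loss * loss_scale)."""
    return to_fp16(true_grad * loss_scale)


def unscale(grad16, loss_scale):
    """Divide out the static scale in full precision before the optimizer step."""
    return grad16 / loss_scale


tiny = 1e-8
print(fp16_grad(tiny))                                   # underflows to 0.0 without scaling
print(unscale(fp16_grad(tiny, LOSS_SCALE), LOSS_SCALE))  # close to 1e-8
```

The same mechanism cuts the other way: with this scale, any true gradient of magnitude 2.0 or more exceeds the fp16 maximum of 65504 after scaling and becomes inf, which is why a static scale has to be chosen conservatively.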
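The integer-quantization item compresses activations and gradients before they are sent between layers, and therefore between nodes. The README does not specify the scheme, so this is a sketch of one common choice, symmetric per-tensor int8 quantization, with hypothetical function names; an int8 payload is a quarter the size of fp32 and half the size of fp16.

```python
def quantize_int8(values):
    """Symmetric per-tensor int8 quantization: returns (ints, scale)."""
    max_abs = max((abs(v) for v in values), default=0.0)
    scale = max_abs / 127.0 if max_abs > 0 else 1.0  # avoid dividing by zero
    # round() maps each value to the nearest integer step; clamp to [-127, 127]
    ints = [max(-127, min(127, round(v / scale))) for v in values]
    return ints, scale


def dequantize_int8(ints, scale):
    """Map the received integers back to floats on the receiving node."""
    return [q * scale for q in ints]


activations = [0.5, -1.27, 0.0, 1.0, 0.333]
q, scale = quantize_int8(activations)
print(q, scale)
print(dequantize_int8(q, scale))
```

Only the integers and the single float `scale` need to cross the network. After dequantization each value is within half a quantization step (`scale / 2`) of the original, and the largest-magnitude entry maps to exactly ±127.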