https://github.com/tensor-fusion/sophia-jax
JAX implementation of 'Sophia: A Scalable Stochastic Second-order Optimizer for Language Model Pre-training'
- Host: GitHub
- URL: https://github.com/tensor-fusion/sophia-jax
- Owner: tensor-fusion
- Created: 2024-05-23T15:14:09.000Z (about 2 years ago)
- Default Branch: main
- Last Pushed: 2024-05-23T21:53:01.000Z (about 2 years ago)
- Last Synced: 2025-02-28T00:33:03.883Z (over 1 year ago)
- Topics: deep-learning, jax, large-language-models, llm, machine-learning, optimization, optimizers, sophia
- Language: Python
- Homepage:
- Size: 259 KB
- Stars: 2
- Watchers: 1
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
README
# Sophia - JAX

JAX implementation of the [Sophia optimizer](https://arxiv.org/abs/2305.14342) for LLM pre-training. The official PyTorch implementation is at https://github.com/Liuhong99/Sophia.

The paper reports a 2x speed-up over Adam on GPT-2 pre-training, measured in steps, total compute, and wall-clock time. Sophia has also reportedly been used on large-scale training runs at Meta, where a similar speed-up was observed: https://x.com/ArmenAgha/status/1780149168692158658
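
For orientation, here is a minimal sketch of the Sophia update rule as described in the paper. It is not this repository's API. The function names (`hutchinson_diag`, `sophia_step`), the flat single-array parameters, and the default hyperparameter values are assumptions made for illustration. The sketch keeps an EMA of the gradient and an EMA of a diagonal Hessian estimate, then applies an element-wise clipped, preconditioned step.

```python
import jax
import jax.numpy as jnp


def hutchinson_diag(loss_fn, params, key):
    """Hutchinson estimate of the Hessian diagonal: u * (H @ u), u ~ N(0, I)."""
    u = jax.random.normal(key, params.shape, params.dtype)
    # Hessian-vector product via forward-over-reverse differentiation.
    _, hvp = jax.jvp(jax.grad(loss_fn), (params,), (u,))
    return u * hvp


def sophia_step(params, m, h, grad, hess_est, *, lr=1e-3, b1=0.96, b2=0.99,
                gamma=0.01, eps=1e-12, weight_decay=0.1):
    """One Sophia-style update. Pass hess_est=None to reuse the previous h."""
    # EMA of the gradient.
    m = b1 * m + (1.0 - b1) * grad
    # EMA of the diagonal Hessian estimate, refreshed only when one is given.
    if hess_est is not None:
        h = b2 * h + (1.0 - b2) * hess_est
    # Decoupled weight decay.
    params = params - lr * weight_decay * params
    # Preconditioned step, clipped element-wise to [-1, 1].
    update = jnp.clip(m / jnp.maximum(gamma * h, eps), -1.0, 1.0)
    return params - lr * update, m, h


# Toy usage on a quadratic loss (hypothetical example, not from the repo).
def loss_fn(x):
    return 0.5 * jnp.sum(jnp.array([1.0, 10.0]) * x ** 2)


params = jnp.array([1.0, -1.0])
m, h = jnp.zeros_like(params), jnp.zeros_like(params)
key = jax.random.PRNGKey(0)
k = 10  # refresh the Hessian estimate every k steps
for step in range(100):
    grad = jax.grad(loss_fn)(params)
    hess_est = None
    if step % k == 0:
        key, sub = jax.random.split(key)
        hess_est = hutchinson_diag(loss_fn, params, sub)
    params, m, h = sophia_step(params, m, h, grad, hess_est,
                               lr=0.01, weight_decay=0.0)
```

Because of the clipping, no coordinate moves by more than `lr` in a single step (before weight decay). Where the Hessian estimate is tiny or negative, the update falls back to a sign-of-momentum step. A real implementation would operate on parameter pytrees and would likely be wrapped as an Optax-style gradient transformation.
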
## TODO
- [ ] Reproduce pretraining results with GPT models
- [ ] Comparisons to AdamW, Lion, etc.
- [ ] etc