https://github.com/evanatyourservice/llm-jax
Train a SmolLM-style llm on fineweb-edu in JAX/Flax with an assortment of optimizers.
- Host: GitHub
- URL: https://github.com/evanatyourservice/llm-jax
- Owner: evanatyourservice
- License: mit
- Created: 2024-08-26T12:41:47.000Z (11 months ago)
- Default Branch: main
- Last Pushed: 2025-03-17T20:14:49.000Z (4 months ago)
- Last Synced: 2025-04-15T20:13:19.718Z (3 months ago)
- Topics: fineweb, flax, gpt, huggingface, jax, large-language-models, llm, optimization
- Language: Python
- Homepage:
- Size: 3.57 MB
- Stars: 17
- Watchers: 1
- Forks: 1
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
README
# llm-jax
Pretrain a SmolLM-style language model on the fineweb-edu dataset. A 350M-parameter model can reach 51% on HellaSwag in only 250B tokens using the PSGD Kron optimizer and architecture improvements.
Ships with several optimizers: PSGD Kron, AdamW, Shampoo, CASPR, and schedule-free. Any optimizer can be wrapped in schedule-free; see `configs.py` for details.
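As a rough sketch of what schedule-free wrapping looks like (this uses Optax's `schedule_free` wrapper directly and is not necessarily how `configs.py` wires it up; the learning rate and `b1` values are placeholders):

```python
import optax

lr = 3e-4  # placeholder learning rate

# Base optimizer with its own momentum disabled; schedule-free replaces
# momentum with an interpolation/averaging scheme of its own.
base = optax.adamw(learning_rate=lr, b1=0.0, weight_decay=0.1)

# Wrap it; the wrapper also takes the learning rate to weight its average.
optimizer = optax.contrib.schedule_free(base, learning_rate=lr, b1=0.95)

# Usage is the same as any optax optimizer:
# opt_state = optimizer.init(params)
# updates, opt_state = optimizer.update(grads, opt_state, params)
# At eval time, optax.contrib.schedule_free_eval_params(opt_state, params)
# returns the averaged parameters.
```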
Only pretraining is set up right now; inference, conversion to PyTorch, and uploading to the Hugging Face Hub are in progress.
Checkpoints are saved to `out_dir`; set the same experiment name to resume.
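A minimal sketch of that resume pattern using Orbax's `CheckpointManager` (illustrative only; the repo's actual checkpoint code and directory layout may differ, and `out_dir`, `my_experiment`, and `init_train_state` below are placeholders):

```python
import orbax.checkpoint as ocp

# One directory per experiment name; reusing the same name resumes training.
mngr = ocp.CheckpointManager("/path/to/out_dir/my_experiment")

train_state = init_train_state()  # placeholder: freshly initialized train state pytree

latest = mngr.latest_step()  # None when no checkpoint exists yet
if latest is not None:
    train_state = mngr.restore(latest, args=ocp.args.StandardRestore(train_state))

# During training, save periodically:
# mngr.save(step, args=ocp.args.StandardSave(train_state))
```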
Set `--profile` to profile training with TensorBoard; the TensorBoard directory is `/profile`.
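This kind of profiling can be done with `jax.profiler`; a minimal sketch, where `train_step`, `state`, and `batch` are placeholders and the repo's `--profile` flag may differ in detail:

```python
import jax

# Trace a handful of steps and write the result where TensorBoard can read it.
with jax.profiler.trace("/profile"):
    for _ in range(5):
        state, loss = train_step(state, batch)  # placeholder training step
    # Ensure async dispatch finishes before the trace closes.
    jax.block_until_ready(loss)
```

View the trace with `tensorboard --logdir /profile` (the `tensorboard-plugin-profile` package is needed for the profiler tab).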
See `configs.py` for other settings and all hyperparameters.
This repo is made possible by [Google's TRC program](https://sites.research.google/trc/about/).
Started with [this repo, credit to @jenkspt](https://github.com/jenkspt/gpt-jax). Also pulled some tools
from [big_vision](https://github.com/google-research/big_vision) to add FSDP sharding. Shoutout to @Grad62304977 for sharing model tips to improve training stability.
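For context, FSDP-style sharding in JAX amounts to laying parameters out over a named device mesh; a rough sketch (not the repo's big_vision-derived utilities, and the array shape is a placeholder):

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# 1-D mesh over all devices, with the axis named "fsdp".
mesh = Mesh(np.array(jax.devices()), axis_names=("fsdp",))

# Shard a weight's first dimension across the fsdp axis so each device
# holds only a slice of the parameter.
sharding = NamedSharding(mesh, P("fsdp"))
w = jax.device_put(jnp.zeros((1024, 512)), sharding)

# jax.jit then keeps computation aligned with this layout,
# gathering parameter shards as needed.
```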
## Install
Clone llm-jax:
```shell
git clone https://github.com/evanatyourservice/llm-jax.git
```
Install Python dependencies (TPU):
```shell
cd llm-jax && pip install -U pip && pip install -U -r requirements.txt && pip install --force-reinstall --upgrade --no-cache-dir 'jax[tpu]' 'jaxlib' -f https://storage.googleapis.com/jax-releases/libtpu_releases.html && pip install 'numpy<2'
```
Install Python dependencies (GPU):
```shell
cd llm-jax && pip install -U pip && pip install -r requirements.txt && pip install --force-reinstall --upgrade --no-cache-dir 'jax[cuda12]' && pip install 'numpy<2'
```
## Run
See examples in /scripts like `scripts/125M_mh_tpu.sh`.
Create a TPU using `queued-resources`:
```shell
gcloud compute tpus queued-resources create node-4 --node-id node-4 --project distributedmuzerojax --zone us-central2-b --accelerator-type v4-16 --runtime-version tpu-ubuntu2204-base --scopes https://www.googleapis.com/auth/cloud-platform
```