{"id":17272646,"url":"https://github.com/codertimo/jax-lm-training","last_synced_at":"2026-03-07T16:31:39.018Z","repository":{"id":40384616,"uuid":"488889200","full_name":"codertimo/jax-lm-training","owner":"codertimo","description":"generative language model training on top of the JAX and Huggingface 🤗","archived":false,"fork":false,"pushed_at":"2022-05-11T15:46:40.000Z","size":963,"stargazers_count":11,"open_issues_count":0,"forks_count":1,"subscribers_count":2,"default_branch":"main","last_synced_at":"2025-04-14T08:21:44.774Z","etag":null,"topics":[],"latest_commit_sha":null,"homepage":null,"language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":null,"status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/codertimo.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":null,"code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null}},"created_at":"2022-05-05T08:25:07.000Z","updated_at":"2024-11-02T15:54:28.000Z","dependencies_parsed_at":"2022-07-14T00:50:28.147Z","dependency_job_id":null,"html_url":"https://github.com/codertimo/jax-lm-training","commit_stats":null,"previous_names":[],"tags_count":0,"template":false,"template_full_name":"codertimo/python-template","purl":"pkg:github/codertimo/jax-lm-training","repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/codertimo%2Fjax-lm-training","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/codertimo%2Fjax-lm-training/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/codertimo%2Fjax-lm-training/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/codertimo%2Fjax-lm-training/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/codertimo","download_url":"https://codeload.github.com/c
odertimo/jax-lm-training/tar.gz/refs/heads/main","sbom_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/codertimo%2Fjax-lm-training/sbom","scorecard":null,"host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":286080680,"owners_count":30221507,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2026-03-07T14:02:48.375Z","status":"ssl_error","status_checked_at":"2026-03-07T14:02:43.192Z","response_time":53,"last_error":"SSL_read: unexpected eof while reading","robots_txt_status":"success","robots_txt_updated_at":"2025-07-24T06:49:26.215Z","robots_txt_url":"https://github.com/robots.txt","online":false,"can_crawl_api":true,"host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":[],"created_at":"2024-10-15T08:49:06.601Z","updated_at":"2026-03-07T16:31:38.981Z","avatar_url":"https://github.com/codertimo.png","language":"Python","funding_links":[],"categories":[],"sub_categories":[],"readme":"# JAX Language Model Training\n\n\u003e Currently this repo is under construction :) \n\n## todos\n\n- [X] writing corpus tokenizing and featurizing code with Apache Beam (done at 5/6)\n- [X] writing training code with single GPU\n- [X] writing evaluation code with single GPU\n- [X] writing metric tracking code with weight\u0026bias\n- [X] train LM with single GPU and debug!\n- [X] writing parallelism code for multi-GPUs support.\n- [X] train LM with multi-GPUs and debug!\n- [ ] training in TPU and TPU-Pod\n- [ ] result comparison GPU, GPUs, TPU, TPU-Pod\n- [ ] writing similar code with pytorch and compare the training performance\n\n\n## step1. 
Pre-Featurize Corpus with Apache Beam\n\nThe `jax_lm.preprocess` script downloads the corpus and processes it into trainable model inputs.\nI wrote an Apache Beam pipeline that processes the corpus and exports the output as Apache Parquet files.\n\n[Apache Beam](https://beam.apache.org/get-started/beam-overview/) is an open source, unified model for defining both batch and streaming data-parallel processing pipelines.\n[Google Dataflow](https://cloud.google.com/dataflow?hl=ko) automatically parallelizes Apache Beam pipelines, which can greatly reduce the processing time for a huge corpus!\n(Well, Dataflow processing code is not implemented in this repo yet. I'll do it soon!)\n\n[Apache Parquet](https://parquet.apache.org) is an open source, column-oriented data file format designed for efficient data storage and retrieval.\nIt provides efficient data compression and encoding schemes with enhanced performance to handle complex data in bulk.\n[huggingface/datasets](https://huggingface.co/docs/datasets/index) can load Parquet files directly!\nSo we just need to use huggingface/datasets when we are training :)\n\nThe procedure is as follows.\n\n1. Download the text corpus from huggingface/datasets\n2. Convert the huggingface/datasets dataset to an Apache Beam PCollection\n3. Tokenize using huggingface/tokenizers with an Apache Beam DoFn\n4. Chunk the tokens into a number of input_ids sequences with an Apache Beam DoFn\n5. Add BOS, EOS, and PAD tokens and build the model inputs (features) with an Apache Beam DoFn\n6. Write the featurized PCollection to Apache Parquet files\n7. 
Now you are ready to train the model!\n\n```shell\npython -m jax_lm.preprocess \\\n    --tokenizer-model \"gpt2\" \\\n    --min-sequence-length 128 \\\n    --max-sequence-length 256 \\\n    --num-special-token-reserved 2 \\\n    --ignore-label -100 \\\n    --stride 128 \\\n    --dataset-name \"wikitext\" \\\n    --dataset-sub-name \"wikitext-103-v1\" \\\n    --dataset-split-type \"train\" \\\n    --output-path \"dataset/wikitext.train\" \\\n    --direct_running_mode \"multi_processing\" \\\n    --direct_num_workers 0\n```\n\n```python\nfrom datasets import Dataset\n\ndataset = Dataset.from_parquet(\"dataset/wikitext.train*\")\n```\n\n## step2. Train model\n\nThe `jax_lm.train` script runs model training with multi-GPU support!\n\n```shell\npython -m jax_lm.train \\\n    --model-config-name \"gpt2\" \\\n    --train-dataset-paths \"dataset/wikitext.train**\" \\\n    --eval-dataset-paths \"dataset/wikitext.test**\" \\\n    --batch-size 16 \\\n    --random-seed 0 \\\n    --max-sequence-length 256 \\\n    --num-epochs 5 \\\n    --learning-rate 3e-5 \\\n    --dtype float32 \\\n    --wandb-username \"codertimo\" \\\n    --wandb-project \"jax-lm-training\" \\\n    --wandb-run-dir \".wandb\" \\\n    --logging-frequency 100 \\\n    --eval-frequency 5000 \\\n    --save-frequency 5000 \\\n    --model-save-dir \"artifacts/\"\n```\n\n**Training Logs**\n\n```\n[TRAIN] epoch: 0 step: 100/161045 loss: 8.8082 ppl: 6688.58 ETA 23:14:51 \n[TRAIN] epoch: 0 step: 200/161045 loss: 5.8967 ppl: 363.83 ETA 09:49:45 \n[TRAIN] epoch: 0 step: 300/161045 loss: 5.0164 ppl: 150.87 ETA 09:49:23 \n...\n[TRAIN] epoch: 4 step: 160800/161045 loss: 3.2986 ppl: 27.07 ETA 00:00:53 \n[TRAIN] epoch: 4 step: 160900/161045 loss: 3.0941 ppl: 22.07 ETA 00:00:31 \n[TRAIN] epoch: 4 step: 161000/161045 loss: 3.0417 ppl: 20.94 ETA 00:00:09 \n[EVAL] epoch: 4 step: 161044/161045 loss: 2.7272 ppl: 15.29 \nsave checkpoint to artifacts/ \n```\n\n## step3. 
Watching W\u0026B and chill\n\n[codertimo/jax-lm-training W\u0026B Board](https://wandb.ai/codertimo/jax-lm-training?workspace=user-codertimo)\n\n![wandb image](docs/wandb.png)\n\nYou can see the training logs as beautiful graphs on W\u0026B!\n\n\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fcodertimo%2Fjax-lm-training","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fcodertimo%2Fjax-lm-training","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fcodertimo%2Fjax-lm-training/lists"}