{"id":16269758,"url":"https://github.com/sea-snell/jaxseq","last_synced_at":"2025-05-07T15:22:06.531Z","repository":{"id":61185838,"uuid":"544320728","full_name":"Sea-Snell/JAXSeq","owner":"Sea-Snell","description":"Train very large language models in Jax.","archived":false,"fork":false,"pushed_at":"2023-10-21T19:26:58.000Z","size":258,"stargazers_count":203,"open_issues_count":0,"forks_count":18,"subscribers_count":9,"default_branch":"main","last_synced_at":"2025-03-31T11:21:13.495Z","etag":null,"topics":["deep-learning","flax","gpt2","gpt3","huggingface","jax","language-models","opt"],"latest_commit_sha":null,"homepage":"","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"mit","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/Sea-Snell.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null,"governance":null,"roadmap":null,"authors":null,"dei":null,"publiccode":null,"codemeta":null}},"created_at":"2022-10-02T07:11:48.000Z","updated_at":"2025-03-03T14:58:27.000Z","dependencies_parsed_at":"2023-01-20T10:16:13.161Z","dependency_job_id":"8e3060bc-5730-4882-b2e0-ba9cea88cf16","html_url":"https://github.com/Sea-Snell/JAXSeq","commit_stats":{"total_commits":72,"total_committers":2,"mean_commits":36.0,"dds":0.04166666666666663,"last_synced_commit":"aa6d7f33c9adc443c51089aec77b4838a0021585"},"previous_names":[],"tags_count":2,"template":false,"template_full_name":null,"repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/Sea-Snell%2FJAXSeq","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/Sea-Snell%2FJAXSeq/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/Sea-Snell%2FJAXSeq/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/Sea-Snell%2FJAXSeq/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/Sea-Snell","download_url":"https://codeload.github.com/Sea-Snell/JAXSeq/tar.gz/refs/heads/main","host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":252902807,"owners_count":21822314,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["deep-learning","flax","gpt2","gpt3","huggingface","jax","language-models","opt"],"created_at":"2024-10-10T18:09:06.406Z","updated_at":"2025-05-07T15:22:06.511Z","avatar_url":"https://github.com/Sea-Snell.png","language":"Python","funding_links":[],"categories":[],"sub_categories":[],"readme":"# JaxSeq\n\n**Note: this is version 2.0 of JaxSeq. It supports Jax v0.4 and includes quite a few updates that should make it easier to work with. However, if you depend on the old version, I would recommend pulling from the `old_version` branch or using the version 1.0 release on GitHub.**\n\n## Overview\n\nBuilt on top of [HuggingFace](https://huggingface.co)'s [Transformers](https://github.com/huggingface/transformers) library, JaxSeq enables training very large language models in [Jax](https://jax.readthedocs.io/en/latest/). It currently supports GPT2, GPTJ, T5, and OPT models. 
JaxSeq is designed to be lightweight and easily extensible, with the aim of demonstrating a workflow for training large language models without the heft typical of other existing frameworks.\n\nThanks to Jax's [pjit](https://jax.readthedocs.io/en/latest/jax.experimental.pjit.html) function, you can straightforwardly train models with arbitrary model and data parallelism, trading off the two as you like. You can also do model parallelism across multiple hosts. Gradient checkpointing, gradient accumulation, and bfloat16 training/inference are also supported for memory-efficient training.\n\n***If you encounter an error or want to contribute, feel free to open an issue!***\n\n## Installation\n\n### **1. Pull from GitHub**\n\n``` bash\ngit clone https://github.com/Sea-Snell/JAXSeq.git\ncd JAXSeq\n```\n\n### **2. Install dependencies**\n\nInstall with conda (CPU, TPU, or GPU).\n\n**Install with conda (CPU):**\n``` shell\nconda env create -f environment.yml\nconda activate JaxSeq\npython -m pip install --upgrade pip\npython -m pip install -e .\n```\n\n**Install with conda (GPU):**\n``` shell\nconda env create -f environment.yml\nconda activate JaxSeq\npython -m pip install --upgrade pip\nconda install jaxlib=*=*cuda* jax cuda-nvcc -c conda-forge -c nvidia\npython -m pip install -e .\n```\n\n**Install with conda (TPU):**\n``` shell\nconda env create -f environment.yml\nconda activate JaxSeq\npython -m pip install --upgrade pip\npython -m pip install \"jax[tpu]\" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html\npython -m pip install -e .\n```\n\n## Workflow\n\nWe provide example scripts in the `examples/` directory for training and evaluating GPT2, GPTJ, LLaMA, and T5 models with JaxSeq, but you should feel free to build your own training workflow. 
Each training script takes JSONL files for the train and eval data as input. Each line of these files should be a JSON object of the form:\n``` json\n{\"in_text\": \"something\", \"out_text\": \"something else\"}\n{\"in_text\": \"something else else\", \"out_text\": \"something else else else\"}\n...\n```\n\nThe examples all use [tyro](https://github.com/brentyi/tyro) to manage command-line arguments (see its [documentation](https://brentyi.github.io/tyro)).\n\nThis code was largely developed, tested, and optimized for use on TPU pods, though it should also work well on GPU clusters.\n\n## Google Cloud Buckets\n\nTo further support TPU workflows, the example scripts can upload and download data and/or checkpoints to and from Google Cloud Storage buckets. To do this, prefix the path with `gcs://`. Depending on the bucket's permissions, you may need to specify the Google Cloud project and provide an authentication token.\n\n## Other Excellent References for Working with Large Models in Jax\n\n* [EasyLM](https://github.com/young-geng/EasyLM)\n* [maxtext](https://github.com/google/maxtext)\n* [DALL-E Mini Repo](https://t.co/BlM8e66utJ)\n* [HuggingFace Model Parallel Jax Demo](https://t.co/eGscnvtNDR)\n* [GPT-J Repo](https://github.com/kingoflolz/mesh-transformer-jax) [uses xmap instead of pjit]\n* [Alpa](https://github.com/alpa-projects/alpa)\n* [Jaxformer](https://github.com/salesforce/jaxformer)\n\n**Many components of this repo came from collaboration with [EasyLM](https://github.com/young-geng/EasyLM).**\n","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fsea-snell%2Fjaxseq","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Fsea-snell%2Fjaxseq","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Fsea-snell%2Fjaxseq/lists"}