<div align="center">

# torchprime

#### High-performance training for PyTorch on Cloud TPU

</div>
<br /><br />

`torchprime` is a reference implementation for training PyTorch models on TPU.
It is designed to showcase best practices for large-scale, high-performance
model training using `torch_xla` ([project][torch_xla]), with minimal changes
to model code. It aims to demystify training on XLA-based accelerators,
providing clear patterns and best practices to help the PyTorch community
unlock top performance and efficiency on Google Cloud TPUs.

`torchprime` is under active development, and we're eager for feedback and
input from the PyTorch community.

## Environment setup

For development and debugging purposes it is useful to run `torchprime`
locally on a TPU VM.
You can create a single-host TPU VM using this guide:
https://cloud.google.com/tpu/docs/managing-tpus-tpu-vm
Alternatively, you can create a TPU VM from the Google Cloud Console for your
cloud project. Spot quota is available for v5e and v6e chips in multiple
regions: https://cloud.google.com/tpu/docs/regions-zones

Make sure that you are using the correct runtime when creating your VM:
https://cloud.google.com/tpu/docs/runtimes#pytorch_and_jax

For example:

```sh
gcloud compute tpus tpu-vm create <tpu-name> \
  --zone=us-central1-a \
  --accelerator-type=v6e-4 \
  --version=v2-alpha-tpuv6e \
  --spot
```

Once the VM is created, you can `ssh` into it:
https://cloud.google.com/tpu/docs/managing-tpus-tpu-vm#tpu-connect

```sh
gcloud compute tpus tpu-vm ssh <tpu-name> --zone=<zone>
```

### Install `torch_xla`

Before installing `torchprime`, you will first need to install
[torch_xla][torch_xla] following its project README. You will need the
nightly version of PyTorch/XLA.

Test that the environment is correctly installed and configured: start a
`python` interpreter and run the following:

```python
import torch_xla.core.xla_model as xm
print("XLA devices:", xm.get_xla_supported_devices())
print("Default XLA device type:", xm.xla_device_hw(xm.xla_device()))
```

### Install `torchprime`

Make sure that `pip` is up-to-date and install the pinned `setuptools`
version:

```sh
python -m pip install --upgrade pip
python -m pip install --upgrade setuptools==69.5.1
```

```sh
git clone https://github.com/AI-Hypercomputer/torchprime.git
cd torchprime
pip install -e '.[dev]'
```

## Examples

### Local training

Here is a simple example of training on a single TPU VM with 4 TPU chips.
Train Llama 3 8B using `torch_xla`:

```sh
export HF_TOKEN='...your huggingface token...'
XLA_IR_DEBUG=1 XLA_HLO_DEBUG=1 python3 torchprime/torch_xla_models/train.py
```

By default, this will distribute the model over 4
TPU chips attached to the VM using the [FSDP parallelism strategy][fsdp].

The first two training steps will take a while to compile. After that, the
graphs will hit the compilation cache and you should see something like this:

```
...
[2025-04-29 06:58:36,445][__main__][INFO] - Num replicas: 1
[2025-04-29 06:58:36,447][__main__][INFO] - Starting training
[2025-04-29 06:58:36,448][__main__][INFO] -     Max step: 15
[2025-04-29 06:58:36,448][__main__][INFO] -     Global batch size: 4
[2025-04-29 07:01:16,240][__main__][INFO] - Epoch: 0, step: 0, loss: 12.5574, trace time: 155003.85 ms
[2025-04-29 07:04:24,182][__main__][INFO] - Epoch: 0, step: 10, loss: 9.7555, trace time: 1564.54 ms
...
```

Refer to `README.md` in `torchprime/torch_xla_models` for more details.

### Configuring training

`torchprime` uses [hydra][hydra] to read configurations (e.g. model name,
batch size) from the command line and `.yaml` files.

In the `torch_xla_models` directory, you'll find a `configs/default.yaml`,
which specifies the default configuration for the trainer. You may override
configs on the command line with a `key=value` syntax.
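Conceptually, a dotted override walks the nested config down to the leaf key and replaces its value. Here is a minimal, hypothetical sketch of that merge logic in plain Python (`apply_override` is an illustrative helper, not torchprime or hydra code):

```python
# Hypothetical sketch of hydra-style "key=value" overrides on a nested
# config dict; the real parsing is done by hydra itself.
def apply_override(config: dict, override: str) -> dict:
    """Apply one dotted override like 'task.global_batch_size=256' in place."""
    dotted_key, raw_value = override.split("=", 1)
    *parents, leaf = dotted_key.split(".")
    node = config
    for key in parents:
        # Create intermediate sections that don't exist yet.
        node = node.setdefault(key, {})
    # Best-effort value parsing: integers stay integers, all else is a string.
    node[leaf] = int(raw_value) if raw_value.lstrip("-").isdigit() else raw_value
    return config

config = {"model": "llama-3-8b", "task": {"global_batch_size": 4}}
for arg in ["task.global_batch_size=256", "ici_mesh.fsdp=64"]:
    apply_override(config, arg)
print(config)
# → {'model': 'llama-3-8b', 'task': {'global_batch_size': 256}, 'ici_mesh': {'fsdp': 64}}
```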
For example, the following command will train Mixtral 8x7B with a global
batch size of 256 and set the FSDP SPMD ICI mesh axis length to 64:

```sh
python3 torchprime/torch_xla_models/train.py \
    model=mixtral-8x7b \
    task=train \
    dataset=wikitext \
    task.global_batch_size=256 \
    ici_mesh.fsdp=64
```

Refer to the [hydra][hydra] docs for other ways to specify configs.

To fine-tune a pretrained model on the GSM8K (grade school math
question-answer) dataset, run:

```sh
python3 torchprime/torch_xla_models/train.py --config-name llama-3-8b-sft-w-gsm8k
```

This uses the `llama-3-8b-sft-w-gsm8k.yaml` config, which selects the SFT
trainer and dataset automatically.

### Multi-VM distributed training

`torchprime` uses [xpk][xpk] as the standard path for iterating on
distributed training code that needs to run on multiple VMs.

First, teach `torchprime` about the XPK cluster it is using, the artifact
storage location, the Google Cloud project/zone, and the TPU topology.
You only need to do this on first clone or when switching to a different
topology or cluster. Example:

```sh
tp use \
    --cluster <XPK CLUSTER NAME> \
    --project <my-gcp-project> \
    --docker-project <my-docker-project-if-it-is-different-from-gcp-project> \
    --zone us-east5-b \
    --num-slices 1 \
    --tpu-type v6e-256 \
    --artifact-dir <my-gs-bucket-dir>
```

`torchprime` natively supports [multi-slice or multi-pod][multi-slice]
training. `--num-slices` specifies the number of [slices][tpu-slice] used by
the workload. `--tpu-type` specifies the [accelerator type][accelerator-type]
in each slice. To do multi-pod training, simply specify a `--tpu-type` that
is as big as a [pod][tpu-pod].

After configuring the cluster, prepend `tp run` to a particular Python file
you would like to run remotely, including arguments, e.g.:

```sh
# Train Llama 3.0 8B on 256 chips
tp run torchprime/torch_xla_models/train.py \
    model=llama-3-8b \
    task.global_batch_size=256 \
    ici_mesh.fsdp=256
```

`tp run` will broadcast the specified command to all VMs in the XPK cluster,
which is the convention for running SPMD distributed workloads.
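The broadcast pattern can be pictured as fanning the identical command out to every VM in parallel. A toy sketch of that fan-out, with a stubbed transport in place of `tp run`'s real mechanism (`run_on_host` and the host names are illustrative, not real torchprime APIs):

```python
# Toy sketch of an SPMD-style broadcast: every VM receives the same command.
from concurrent.futures import ThreadPoolExecutor

def run_on_host(host: str, command: str) -> str:
    # Stand-in for the real transport (e.g. ssh or a cluster scheduler).
    return f"{host}: ran '{command}'"

def broadcast(hosts: list[str], command: str) -> list[str]:
    # Fan the identical command out to all hosts concurrently.
    with ThreadPoolExecutor(max_workers=len(hosts)) as pool:
        return list(pool.map(lambda h: run_on_host(h, command), hosts))

results = broadcast(["vm-0", "vm-1", "vm-2"], "python3 train.py")
assert len(results) == 3  # one result per VM
```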
See `tp run --help` for more advanced features.

The version of `torch_xla` used by `tp run` is specified in `pyproject.toml`.

#### Env vars passed to the workload

`tp run` will pick up these environment variables locally and proxy them to
the distributed workload, if found:

- `HF_TOKEN`: HuggingFace token
- `XLA_IR_DEBUG`: [torch_xla debugging flag][torch_xla_debug_env]
- `XLA_HLO_DEBUG`: [torch_xla debugging flag][torch_xla_debug_env]
- `LIBTPU_INIT_ARGS`: XLA flags that affect compilation and execution behavior

#### Additional CLI arguments passed to the workload

Besides forwarding your command line arguments, `tp run` will add:

- `profile_dir=[...]`: path to a [profile][torch_xla_profile] directory
  accessible by the workload
- `output_dir=[...]`: path to a directory where the workload may store output
  artifacts such as metrics and checkpoints

#### Troubleshooting distributed training setup

See the guide to [troubleshoot distributed training][troubleshoot-distributed]
for troubleshooting tools and an FAQ.

## Supported Models

torchprime has implementations for the following models:

- [Llama 3.0 8B](torchprime/torch_xla_models/README.md#llama-30-8b-on-v6e-256)
- [Llama 3.1 8B](torchprime/torch_xla_models/README.md#llama-31-8b-on-v6e-256)
- [Llama 3.1 70B](torchprime/torch_xla_models/README.md#llama-31-70b-on-v6e-256)
- [Llama 3.1 405B](torchprime/torch_xla_models/README.md#llama-31-405b-on-v6e-256)
- [Mixtral 8x7B](torchprime/torch_xla_models/README.md#mixtral-8x7b-on-v6e-256)

Every implemented model has a training recipe and is backed by unit tests.

Interested in another model? File an
[issue](https://github.com/AI-Hypercomputer/torchprime/issues).

## Structure

This repo contains a set of reference models that we have optimized and that
run well on TPU. The best-performing scaling configuration (parallelism
techniques, checkpointing, etc.)
for a model on various hardware is provided for ease of reproducibility.

`docs` contains guides for optimizing performance and debugging issues.

`torchprime/launcher` contains scripts to train a model on a large TPU
cluster.

`torchprime/data` contains dataset and data loading utilities.

`torchprime/torch_xla_models` contains model implementations using
`torch_xla`.

`torchprime/experimental` contains experimental model implementations.

Finally, each model may also provide a GPU "original" version that
illustrates and attributes where the model code came from, if any. This also
helps to showcase what changes we have made to make it performant on TPU. The
original version is not expected to be run.

## Contributing

Contributions are welcome! Please feel free to submit a pull request.

Refer to the [contributor guide](./docs/contributor/README.md) to get
started.

## Questions and suggestions

Feel free to ask questions in the [Discussions][discussions] panel, or to
look at previous questions and discussions.

## License

This project is licensed under the New BSD License - see the
[LICENSE](LICENSE) file for details.

For more information on PyTorch/XLA, visit the [official
documentation](https://github.com/pytorch/xla).

[torch_xla]: https://github.com/pytorch/xla
[fsdp]: https://jax-ml.github.io/scaling-book/training/#fully-sharded-data-parallelism-fsdp
[discussions]: https://github.com/AI-Hypercomputer/torchprime/discussions/categories/q-a
[xpk]: https://github.com/AI-Hypercomputer/xpk
[torch_xla_debug_env]:
    https://github.com/pytorch/xla/blob/master/docs/source/learn/troubleshoot.md#environment-variables
[torch_xla_profile]:
    https://github.com/pytorch/xla/blob/master/docs/source/learn/troubleshoot.md#performance-profiling
[hydra]: https://hydra.cc/docs/intro/
[tpu-pod]:
    https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#tpu-pod
[tpu-slice]: https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#slices
[accelerator-type]: https://cloud.google.com/tpu/docs/multislice-introduction#concepts
[multi-slice]: https://cloud.google.com/tpu/docs/multislice-introduction
[troubleshoot-distributed]: https://github.com/AI-Hypercomputer/torchprime/blob/main/docs/troubleshoot-distributed.md