# Jetstream-PyTorch
JetStream engine implementation in PyTorch: a PyTorch/XLA integration with [JetStream](https://github.com/google/JetStream) for LLM inference.

# Latest Release

The latest release is tagged `jetstream-v0.2.3`. If you are running the release version, please follow the README for that version:
https://github.com/google/jetstream-pytorch/blob/jetstream-v0.2.3/README.md

Command-line flags may have changed between the release version and HEAD.

# Outline

1. SSH to a Cloud TPU VM (using a v5e-8 TPU VM)
   * Create a Cloud TPU VM if you haven't already
2. Download the jetstream-pytorch GitHub repo and install dependencies
3. Run the server
4. Run benchmarks
5. Typical errors

# SSH to a Cloud TPU VM (using a v5e-8 TPU VM)

```bash
gcloud compute config-ssh
gcloud compute tpus tpu-vm ssh "your-tpu-vm" --project "your-project" --zone "your-project-zone"
```

## Create a Cloud TPU VM in a GCP project if you haven't already
Follow the steps in
* https://cloud.google.com/tpu/docs/managing-tpus-tpu-vm

# Clone the repo and install dependencies

## Get the jetstream-pytorch code
```bash
git clone https://github.com/google/jetstream-pytorch.git
cd jetstream-pytorch
git checkout jetstream-v0.2.4
```

(Optional) Create a virtual environment using `venv` or `conda` and activate it.

## Run the installation script

From the `jetstream-pytorch` directory, run:

```bash
source install_everything.sh
```

# Run jetstream-pytorch

## List supported models

```bash
jpt list
```

This prints the supported models and variants:

```
meta-llama/Llama-2-7b-chat-hf
meta-llama/Llama-2-7b-hf
meta-llama/Llama-2-13b-chat-hf
meta-llama/Llama-2-13b-hf
meta-llama/Llama-2-70b-hf
meta-llama/Llama-2-70b-chat-hf
meta-llama/Meta-Llama-3-8B
meta-llama/Meta-Llama-3-8B-Instruct
meta-llama/Meta-Llama-3-70B
meta-llama/Meta-Llama-3-70B-Instruct
meta-llama/Llama-3.1-8B
meta-llama/Llama-3.1-8B-Instruct
meta-llama/Llama-3.2-1B
meta-llama/Llama-3.2-1B-Instruct
meta-llama/Llama-3.3-70B
meta-llama/Llama-3.3-70B-Instruct
google/gemma-2b
google/gemma-2b-it
google/gemma-7b
google/gemma-7b-it
mistralai/Mixtral-8x7B-v0.1
mistralai/Mixtral-8x7B-Instruct-v0.1
```

To run the jetstream-pytorch server with one model:

```bash
jpt serve --model_id meta-llama/Meta-Llama-3-8B-Instruct
```

The first time you run a model, its weights are downloaded from HuggingFace.

HuggingFace's Llama 3 weights are gated. You must either run `huggingface-cli login` to set your token, or pass your HuggingFace token explicitly.

To pass the token explicitly, add the `--hf_token` flag:
```bash
jpt serve --model_id meta-llama/Meta-Llama-3-8B-Instruct --hf_token=...
```

To log in with the HuggingFace Hub CLI, run:

```bash
pip install -U "huggingface_hub[cli]"
huggingface-cli login
```
Then follow its prompts.

Once the weights have been downloaded, `--hf_token` is no longer required on later runs.

To run the model with `int8` weight quantization, add `--quantize_weights=1`. Quantization is done on the fly as the weights load.

By default, weights downloaded from HuggingFace are stored in a `checkpoints` folder inside the directory where `jpt` is executed. You can change where the weights are stored with the `--working_dir` flag.

To use your own checkpoint, place its files inside the `checkpoints/<org>/<model>/hf_original` directory (or the corresponding subdirectory of `--working_dir`). For example, the Llama-2-7b-hf checkpoint lives at `checkpoints/meta-llama/Llama-2-7b-hf/hf_original/*.safetensors`. You can replace these files with modified weights in HuggingFace format.

## Send one request

Jetstream-pytorch handles requests over gRPC. The script below shows how to send a gRPC request from Python. Any other gRPC client works as well.

```python
import grpc

from jetstream.core.proto import jetstream_pb2
from jetstream.core.proto import jetstream_pb2_grpc

prompt = "What are the top 5 languages?"

# Address of the JetStream gRPC server (port 8888 in these examples).
with grpc.insecure_channel("localhost:8888") as channel:
    stub = jetstream_pb2_grpc.OrchestratorStub(channel)

    request = jetstream_pb2.DecodeRequest(
        text_content=jetstream_pb2.DecodeRequest.TextContent(
            text=prompt
        ),
        priority=0,
        max_tokens=2000,
    )

    # Decode streams its results: collect the text chunk from each response.
    output = []
    for resp in stub.Decode(request):
        output.append(resp.stream_content.samples[0].text)

text_output = "".join(output)
print(f"Prompt: {prompt}")
print(f"Response: {text_output}")
```

# Run the server with Ray
The steps to run the server with Ray are:
1. SSH to a multi-host Cloud TPU VM (v5e-16 TPU VM)
2. Follow steps 2 to 5 of the Outline
3. Set up the Ray cluster
4. Run the server with Ray

## Set up the Ray cluster
Log in to the host 0 VM and start the Ray head node:

```bash
ray start --head
```

Log in to each of the other host VMs and join the cluster as a worker node:

```bash
ray start --address="$ip:$port"
```

Note: Get the head node's IP address and port from the output of `ray start --head`.

## Run the server with Ray

Here is an example that runs the server with Ray for the Llama 2 7B model:

```bash
export DISABLE_XLA2_PJRT_TEST="true"
python run_server_with_ray.py --tpu_chips=16 --num_hosts=4 --worker_chips=4 \
  --model_name=$model_name --size=7b --batch_size=96 --max_cache_length=2048 \
  --quantize_weights=$quantize --quantize_type=$quantize_type --quantize_kv_cache=$quantize \
  --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path \
  --sharding_config="default_shardings/llama.yaml"
```

# Run benchmarks
Start the server, then go to the `deps/JetStream` folder (downloaded by `install_everything.sh`):

```bash
cd deps/JetStream
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json
export dataset_path=ShareGPT_V3_unfiltered_cleaned_split.json
python benchmarks/benchmark_serving.py --tokenizer $tokenizer_path --num-prompts 2000 \
  --dataset-path $dataset_path --dataset sharegpt --save-request-outputs \
  --warmup-mode=sampled --model=$model_name
```
See `deps/JetStream/benchmarks/README.md` for more information.

# Run the server with Ray Serve

## Prerequisites

If running on GKE:

1. Follow [these instructions](https://github.com/GoogleCloudPlatform/ai-on-gke/tree/main/ray-on-gke/guides/tpu) to set up a GKE cluster and the TPU webhook.
2. Follow [these instructions](https://cloud.google.com/kubernetes-engine/docs/how-to/persistent-volumes/cloud-storage-fuse-csi-driver) to enable GCSFuse for your cluster. GCSFuse is needed to store the converted weights.
3. Deploy one of the sample KubeRay cluster configurations:
   ```bash
   kubectl apply -f kuberay/manifests/ray-cluster.tpu-v4-singlehost.yaml
   ```
   or
   ```bash
   kubectl apply -f kuberay/manifests/ray-cluster.tpu-v4-multihost.yaml
   ```

## Start a Ray Serve deployment

Single-host (Llama 2 7B):

```bash
export RAY_ADDRESS=http://localhost:8265

kubectl port-forward svc/example-cluster-kuberay-head-svc 8265:8265 &

ray job submit --runtime-env-json='{"working_dir": "."}' -- python run_ray_serve_interleave.py \
  --tpu_chips=4 --num_hosts=1 --size=7b --model_name=llama-2 --batch_size=32 --max_cache_length=2048 \
  --tokenizer_path=/llama/tokenizer.model --checkpoint_path=/llama/ckpt \
  --quantize_weights=True --quantize_type="int8_per_channel" --quantize_kv_cache=True \
  --sharding_config="default_shardings/llama.yaml"
```

Multi-host (Llama 2 70B):

```bash
export RAY_ADDRESS=http://localhost:8265

kubectl port-forward svc/example-cluster-kuberay-head-svc 8265:8265 &

ray job submit --runtime-env-json='{"working_dir": "."}' -- python run_ray_serve_interleave.py \
  --tpu_chips=8 --num_hosts=2 --size=70b --model_name=llama-2 --batch_size=8 --max_cache_length=2048 \
  --tokenizer_path=/llama/tokenizer.model --checkpoint_path=/llama/ckpt \
  --quantize_weights=True --quantize_type="int8_per_channel" --quantize_kv_cache=True \
  --sharding_config="default_shardings/llama.yaml"
```

## Sending an inference request

Port-forward port 8888 for gRPC:
```bash
kubectl port-forward svc/example-cluster-kuberay-head-svc 8888:8888 &
```

Then send a request with the Python script from [Send one request](#send-one-request). That script connects to `localhost:8888`, which the port-forward above exposes.

# Typical Errors

## Unexpected keyword argument 'device'

Fix:
* Uninstall the `jax` and `jaxlib` packages (for example, `pip uninstall jax jaxlib`)
* Reinstall them by running `source install_everything.sh`

## Out of memory

Fix:
* Use a smaller batch size (`--batch_size`)
* Use quantization (for example, `--quantize_weights` and `--quantize_kv_cache`)

# Links

## JetStream
* https://github.com/google/JetStream

## MaxText
* https://github.com/google/maxtext
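The "Run jetstream-pytorch" section says that `--quantize_weights=1` quantizes weights to `int8` on the fly. The Ray Serve examples also pass `--quantize_type="int8_per_channel"`. The sketch below shows what symmetric per-channel int8 quantization generally looks like. It assumes a 2-D weight with one scale per output row and uses NumPy as the array library. It illustrates the general technique and is not jetstream-pytorch's own quantization code. The project's scale layout, rounding, and storage may differ.

```python
import numpy as np


def quantize_int8_per_channel(weight: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
    """Symmetric int8 quantization with one scale per output channel (row).

    Sketch only: assumes a 2-D weight of shape (out_features, in_features).
    """
    # Largest magnitude in each row; keepdims so it broadcasts against `weight`.
    max_abs = np.max(np.abs(weight), axis=1, keepdims=True)
    # Map [-max_abs, max_abs] onto [-127, 127]; all-zero rows get scale 1 to avoid 0/0.
    scale = np.where(max_abs == 0, 1.0, max_abs / 127.0).astype(np.float32)
    q = np.clip(np.round(weight / scale), -127, 127).astype(np.int8)
    return q, scale


def dequantize(q: np.ndarray, scale: np.ndarray) -> np.ndarray:
    """Recover an approximate float weight from int8 values and per-row scales."""
    return q.astype(np.float32) * scale


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    w = rng.standard_normal((4, 8)).astype(np.float32)
    q, s = quantize_int8_per_channel(w)
    err = np.max(np.abs(dequantize(q, s) - w))
    print(q.dtype, s.shape, f"max abs error: {err:.4f}")
```

Each row keeps its own scale, so a row with small values does not lose precision to a large outlier in another row. The rounding error per element is at most half of that row's scale.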
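The "Run jetstream-pytorch" section places custom weights under `checkpoints/<org>/<model>/hf_original`, or under the corresponding subdirectory of `--working_dir`. The helper below is hypothetical and not part of `jpt`. It shows one way to check that a checkpoint sits where the server expects it before you start the server. It assumes three things: `--working_dir` takes the place of the default `checkpoints` folder, the model ID is the `<org>/<model>` string printed by `jpt list`, and the weights are `*.safetensors` files, as in the Llama-2-7b-hf example path.

```python
from pathlib import Path
from typing import List, Optional


def hf_original_dir(model_id: str, working_dir: Optional[str] = None) -> Path:
    """Hypothetical helper: <working_dir or ./checkpoints>/<org>/<model>/hf_original."""
    # Assumption: --working_dir replaces the default "checkpoints" folder.
    base = Path(working_dir) if working_dir else Path("checkpoints")
    org, sep, model = model_id.partition("/")
    if not sep or not org or not model or "/" in model:
        raise ValueError(f"expected '<org>/<model>', got {model_id!r}")
    return base / org / model / "hf_original"


def list_safetensors(model_id: str, working_dir: Optional[str] = None) -> List[Path]:
    """Return the *.safetensors files in the expected directory (empty if none or missing)."""
    return sorted(hf_original_dir(model_id, working_dir).glob("*.safetensors"))


if __name__ == "__main__":
    model = "meta-llama/Llama-2-7b-hf"
    print(hf_original_dir(model))
    print(f"{len(list_safetensors(model))} safetensors file(s) found")
```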
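The "Out of memory" fixes (a smaller batch size, quantization) follow from how the KV cache grows. A rough estimate multiplies these factors:

* 2 (keys and values)
* number of layers
* batch size
* cache length
* number of KV heads
* head dimension
* bytes per element

The sketch applies that formula to the published Llama 2 7B shape: 32 layers, 32 KV heads, and head dimension 128. It uses the `--batch_size=96 --max_cache_length=2048` values from the Ray example. The estimate ignores weights, activations, padding, and how the cache is sharded across chips. Treat it as order-of-magnitude guidance, not as what the server actually allocates.

```python
def kv_cache_bytes(
    num_layers: int,
    batch_size: int,
    cache_length: int,
    num_kv_heads: int,
    head_dim: int,
    bytes_per_element: int,
) -> int:
    """Approximate KV cache size: keys + values for every layer, batch slot, and position."""
    return (
        2 * num_layers * batch_size * cache_length
        * num_kv_heads * head_dim * bytes_per_element
    )


GIB = 1024 ** 3

if __name__ == "__main__":
    # Llama 2 7B: 32 layers, 32 KV heads (no grouped-query attention), head_dim 128.
    llama2_7b = dict(num_layers=32, num_kv_heads=32, head_dim=128)
    for batch in (96, 32):
        for name, nbytes in (("bf16", 2), ("int8", 1)):
            size = kv_cache_bytes(
                batch_size=batch, cache_length=2048, bytes_per_element=nbytes, **llama2_7b
            )
            print(f"batch={batch:3d} {name}: {size / GIB:.0f} GiB")
```

With these inputs, the estimate at batch size 96 is 96 GiB for a bf16 cache and 48 GiB for an int8 cache. At batch size 32 it is one third of those figures. Memory scales linearly with both batch size and bytes per element.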
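For the "Unexpected keyword argument 'device'" error, the documented fix is to uninstall `jax` and `jaxlib` and rerun `install_everything.sh`. It can help to record which versions are installed before and after reinstalling. The sketch below uses only the standard library's `importlib.metadata`, so it still runs when `jax` itself fails to import. It does not judge which versions are compatible, since that depends on what `install_everything.sh` installs.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Iterable, Optional


def installed_versions(packages: Iterable[str] = ("jax", "jaxlib")) -> Dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if it is missing."""
    result: Dict[str, Optional[str]] = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    for name, ver in installed_versions().items():
        print(f"{name}: {ver or 'not installed'}")
```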