# tpu-pod-tutorial
This tutorial will get you started with multi-host TPU VM pods.
If you don't want to follow the tutorial step by step, an already-edited reference version
is on the `run` branch.

## Set up this tutorial
First, fork this repo on GitHub. Then, clone your fork
onto one of the hosts on the TPU VM.

```sh
git clone https://github.com/your_username/tpu-pod-tutorial ~/your_folder/tpu-pod-tutorial
```

## Set up the uv project
To make the builds reproducible across pods, we will be using
`uv` as a package manager. It's like `pip` but *much* better.

> If `uv` is not installed, follow the instructions [here](https://docs.astral.sh/uv/getting-started/installation/).
`cd` into the cloned `tpu-pod-tutorial` folder and initialize `uv`:
```sh
uv init .
```

You should see something like:
```
$ uv init .
Initialized project `tpu-pod-tutorial` at `/home/ucsdwanglab/your_folder/tpu-pod-tutorial`
$ ls
LICENSE README.md hello.py pyproject.toml
```

Let's add JAX as a dependency:
```sh
uv add --prerelease=allow "jax[tpu]" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
```

```
Using CPython 3.10.16
Creating virtual environment at: .venv
Resolved 14 packages in 154ms
Installed 13 packages in 61ms
+ certifi==2024.12.14
+ charset-normalizer==3.4.1
+ idna==3.10
+ jax==0.5.0
+ jaxlib==0.5.0
+ libtpu==0.0.8
+ libtpu-nightly==0.1.dev20241010+nightly.cleanup
+ ml-dtypes==0.5.1
+ numpy==2.2.1
+ opt-einsum==3.4.0
+ requests==2.32.3
+ scipy==1.15.1
+ urllib3==2.3.0
$ ls
LICENSE README.md hello.py pyproject.toml uv.lock
```

## Run distributed code
Our program is *distributed*. This means that it involves several machines
(we call them hosts), each of which runs a process, and the processes communicate with each other over
the network. Crucially, the hosts **do not share a filesystem**, so we need to distribute all of our code and data
before we can run anything.

Let's begin.
First, let's make sure we can connect to all the hosts. To do this, we will use the `eopod` tool,
which can be run with `uv`'s execution tool `uvx`.

```sh
uvx eopod run echo "Hello, world!"
```

```
Started at: 2025-01-18 20:27:48
Executing: echo Hello, world!
Executing command on worker all: echo Hello, world!
Using ssh batch size of 1. Attempting to SSH into 1 nodes with a total of 4 workers.
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
Hello, world!
Hello, world!
Hello, world!
Hello, world!
Command completed successfully
Command completed successfully!
Completed at: 2025-01-18 20:27:52
Duration: 0:00:03.858619
```

You should see one `Hello, world!` per host. I have 4 hosts, so I see 4.
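If you want to convince yourself that the hosts really don't share a filesystem, here is an optional sanity check (not part of the original tutorial; the marker file name is arbitrary):

```sh
# Create a marker file on the host you are logged into...
touch ~/only_on_this_host
# ...then look for it on every worker. Only this host should report "present".
uvx eopod run '[ -f ~/only_on_this_host ] && echo present || echo absent'
```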
Running Python code is a bit more involved. First, let's install `uv`
on all hosts:

```sh
uvx eopod run 'curl -LsSf https://astral.sh/uv/install.sh | sh'
```

```
Started at: 2025-01-18 21:02:46
Executing: curl -LsSf https://astral.sh/uv/install.sh | sh
Executing command on worker all: curl -LsSf
https://astral.sh/uv/install.sh | sh
Using ssh batch size of 1. Attempting to SSH into 1 nodes with a total of 4 workers.
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
downloading uv 0.5.21 x86_64-unknown-linux-gnu
downloading uv 0.5.21 x86_64-unknown-linux-gnu
downloading uv 0.5.21 x86_64-unknown-linux-gnu
downloading uv 0.5.21 x86_64-unknown-linux-gnu
no checksums to verify
no checksums to verify
no checksums to verify
installing to /home/ucsdwanglab/.local/bin
no checksums to verify
installing to /home/ucsdwanglab/.local/bin
uv
uvx
everything's installed!
uv
uvx
everything's installed!
installing to /home/ucsdwanglab/.local/bin
uv
uvx
everything's installed!
installing to /home/ucsdwanglab/.local/bin
uv
uvx
everything's installed!
Command completed successfully
```

And verify that `uv` is available:
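The tutorial doesn't spell out the exact check; one way to do it (a sketch, assuming the default install location `~/.local/bin` shown in the installer output above) is:

```sh
uvx eopod run '~/.local/bin/uv --version'
```

Every worker should print the same `uv` version.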
Edit `hello.py` to contain
```python
import jax

jax.distributed.initialize()

if jax.process_index() == 0:
    print(jax.device_count())
```

This will initialize JAX's distributed computing system, and print the total number
of detected devices from Host 0.

Remember that our hosts don't share a filesystem, so we need to sync this code among all the machines. We'll
use `git` to do this:

```sh
git add . && git commit -am "Add hello.py" && git push
```

Next, we clone this repo on all hosts:
```sh
uvx eopod run git clone https://github.com/your_username/tpu-pod-tutorial ~/your_folder/tpu-pod-tutorial
```

Make sure we get a `Success` from every host:
```sh
uvx eopod run '[ -d ~/your_folder/tpu-pod-tutorial ] && echo Success || echo Fail'
```

Now, we can run the program. Create `run.sh` with the following contents:
```bash
#!/bin/bash

if [ $# -lt 1 ]; then
    echo "Usage: $0 <python_file>"
    exit 1
fi

PYTHON_FILE="$1"

# Free up the TPU by killing any processes still attached to it.
eopod kill-tpu --force

# Export this host's PATH so the non-interactive SSH shells on each worker can
# find `uv`, then pull the latest code and run the requested file everywhere.
eopod run "
export PATH=$PATH &&
cd ~/your_folder/tpu-pod-tutorial &&
git pull &&
uv run --prerelease allow $PYTHON_FILE
"
```

And run:
```sh
bash run.sh hello.py
```

```
$ bash run.sh
⠋ Scanning for TPU processes...Fetching TPU status...
⠙ Scanning for TPU processes...TPU state: READY
Executing command on worker 0: ps aux | grep -E
'python|jax|tensorflow' | grep -v grep | awk '{print $2}' | while read
pid; do if [ -d /proc/$pid ] && grep -q 'accel' /proc/$pid/maps
2>/dev/null; then echo $pid; fi; done
Executing command on worker 1: ps aux | grep -E
'python|jax|tensorflow' | grep -v grep | awk '{print $2}' | while read
pid; do if [ -d /proc/$pid ] && grep -q 'accel' /proc/$pid/maps
2>/dev/null; then echo $pid; fi; done
Executing command on worker 2: ps aux | grep -E
'python|jax|tensorflow' | grep -v grep | awk '{print $2}' | while read
pid; do if [ -d /proc/$pid ] && grep -q 'accel' /proc/$pid/maps
2>/dev/null; then echo $pid; fi; done
Executing command on worker 3: ps aux | grep -E
'python|jax|tensorflow' | grep -v grep | awk '{print $2}' | while read
pid; do if [ -d /proc/$pid ] && grep -q 'accel' /proc/$pid/maps
2>/dev/null; then echo $pid; fi; done
⠏ Scanning for TPU processes...Command completed successfully
Command completed successfully
⠋ Scanning for TPU processes...Command completed successfully
⠙ Scanning for TPU processes...Command completed successfully
No TPU processes found.
⠙ Scanning for TPU processes...
Started at: 2025-01-18 21:00:12
Executing:
export
PATH=/home/ucsdwanglab/.local/bin:/home/ucsdwanglab/.local/bin:/usr/lo
cal/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/games:/usr/
local/games:/snap/bin:/home/ucsdwanglab/.local/bin &&
cd ~/your_folder/tpu-pod-tutorial &&
git checkout run &&
git pull &&
uv run --prerelease allow hello.py
Executing command on worker all:
export
PATH=/home/ucsdwanglab/.local/bin:/home/ucsdwanglab/.local/bin:/usr/lo
cal/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/games:/usr/
local/games:/snap/bin:/home/ucsdwanglab/.local/bin &&
cd ~/your_folder/tpu-pod-tutorial &&
git checkout run &&
git pull &&
uv run --prerelease allow hello.py
Using ssh batch size of 1. Attempting to SSH into 1 nodes with a total of 4 workers.
SSH: Attempting to connect to worker 0...
SSH: Attempting to connect to worker 1...
SSH: Attempting to connect to worker 2...
SSH: Attempting to connect to worker 3...
Already on 'run'
M README.md
Your branch is up to date with 'origin/run'.
Already on 'run'
Your branch is up to date with 'origin/run'.
Already on 'run'
Your branch is up to date with 'origin/run'.
Already on 'run'
Your branch is up to date with 'origin/run'.
Already up to date.
Already up to date.
Already up to date.
Already up to date.
16
Command completed successfully
Command completed successfully!
Completed at: 2025-01-18 21:00:20
Duration: 0:00:08.020470
```

We see there are 16 detected devices on the `v4-32` pod (4 hosts with 4 chips each)!
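If you're curious which devices live on which host, a small variation of `hello.py` (a sketch, not part of the original tutorial) prints the owning process for each device:

```python
import jax

jax.distributed.initialize()

if jax.process_index() == 0:
    # Each device records the index of the process (host) that owns it.
    for d in jax.devices():
        print(d, "-> host", d.process_index)
```

You can run it the same way, e.g. `bash run.sh devices.py` after committing it as a hypothetical `devices.py`.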
Let's try running a distributed computation. Create a file `pmap.py` in the repo
with the following contents:

```python
# The following code snippet will be run on all TPU hosts
import jax

# The total number of TPU cores in the Pod
device_count = jax.device_count()
# The number of TPU cores attached to this host
local_device_count = jax.local_device_count()

# The psum is performed over all mapped devices across the Pod
xs = jax.numpy.ones(jax.local_device_count())
r = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')(xs)

# Print from a single host to avoid duplicated output
if jax.process_index() == 0:
print('global device count:', jax.device_count())
print('local device count:', jax.local_device_count())
print('pmap result:', r)
```

Commit and push:
```sh
git add pmap.py && git commit -am "Add pmap.py" && git push
```

And run:
```sh
bash run.sh pmap.py
```

After all the setup output, it should show something like:
```
global device count: 16
local device count: 4
pmap result: [16. 16. 16. 16.]
```

The `psum` adds the value `1.0` from every device in the pod, so each entry of the result is 16; and since this host has 4 local devices, `r` has 4 entries.