https://github.com/pytorch/xla
Enabling PyTorch on XLA Devices (e.g. Google TPU)
compiler deep-learning pytorch xla
- Host: GitHub
- URL: https://github.com/pytorch/xla
- Owner: pytorch
- License: other
- Created: 2018-11-05T22:42:04.000Z (over 6 years ago)
- Default Branch: master
- Last Pushed: 2025-04-02T00:58:09.000Z (13 days ago)
- Last Synced: 2025-04-02T01:16:17.428Z (13 days ago)
- Topics: compiler, deep-learning, pytorch, xla
- Language: Python
- Homepage: https://pytorch.org/xla
- Size: 70.7 MB
- Stars: 2,574
- Watchers: 53
- Forks: 505
- Open Issues: 806
Metadata Files:
- Readme: README.md
- Contributing: CONTRIBUTING.md
- License: LICENSE
- Code of conduct: CODE_OF_CONDUCT.md
- Codeowners: .github/CODEOWNERS
Awesome Lists containing this project
- Awesome-pytorch-list-CNVersion - PyTorch/XLA
- stars - pytorch/xla - Enabling PyTorch on XLA Devices (e.g. Google TPU) (C++)
- Awesome-pytorch-list - PyTorch/XLA
README
# PyTorch/XLA
PyTorch/XLA is a Python package that uses the [XLA deep learning
compiler](https://www.tensorflow.org/xla) to connect the [PyTorch deep learning
framework](https://pytorch.org/) and [Cloud
TPUs](https://cloud.google.com/tpu/). You can try it right now, for free, on a
single Cloud TPU VM with
[Kaggle](https://www.kaggle.com/discussions/product-feedback/369338)!

Take a look at one of our [Kaggle
notebooks](https://github.com/pytorch/xla/tree/master/contrib/kaggle) to get
started:

* [Stable Diffusion with PyTorch/XLA
  2.0](https://github.com/pytorch/xla/blob/master/contrib/kaggle/pytorch-xla-2-0-on-kaggle.ipynb)
* [Distributed PyTorch/XLA
  Basics](https://github.com/pytorch/xla/blob/master/contrib/kaggle/distributed-pytorch-xla-basics-with-pjrt.ipynb)

## Installation
### TPU
To install the PyTorch/XLA stable build in a new TPU VM:
```sh
pip install torch~=2.6.0 'torch_xla[tpu]~=2.6.0' \
-f https://storage.googleapis.com/libtpu-releases/index.html \
-f https://storage.googleapis.com/libtpu-wheels/index.html
```

To install the PyTorch/XLA nightly build in a new TPU VM:
```sh
pip install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
pip install 'torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev-cp310-cp310-linux_x86_64.whl' \
-f https://storage.googleapis.com/libtpu-releases/index.html \
-f https://storage.googleapis.com/libtpu-wheels/index.html
```

### C++11 ABI builds

**As of 03/18/2025 and starting from the PyTorch/XLA 2.7 release, C++11 ABI builds
are the default and we no longer provide wheels built with the pre-C++11 ABI.**

In PyTorch/XLA 2.6, we provide wheels and docker images built with
two C++ ABI flavors: C++11 and pre-C++11. Pre-C++11 is the default to align with
PyTorch upstream, but C++11 ABI wheels and docker images have better lazy tensor
tracing performance.

To install C++11 ABI flavored 2.6 wheels (Python 3.10 example):
```sh
pip install torch==2.6.0+cpu.cxx11.abi \
https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0%2Bcxx11-cp310-cp310-manylinux_2_28_x86_64.whl \
'torch_xla[tpu]' \
-f https://storage.googleapis.com/libtpu-releases/index.html \
-f https://storage.googleapis.com/libtpu-wheels/index.html \
-f https://download.pytorch.org/whl/torch
```

The above command works for Python 3.10. We additionally have Python 3.9 and 3.11
wheels:

- 3.9: https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0%2Bcxx11-cp39-cp39-manylinux_2_28_x86_64.whl
- 3.10: https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0%2Bcxx11-cp310-cp310-manylinux_2_28_x86_64.whl
- 3.11: https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0%2Bcxx11-cp311-cp311-manylinux_2_28_x86_64.whl

To access the C++11 ABI flavored docker image:
```
us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.6.0_3.10_tpuvm_cxx11
```

If your model is tracing bound (e.g. you see that the host CPU is busy tracing
the model while TPUs are idle), switching to the C++11 ABI wheels/docker images
can improve performance. Mixtral 8x7B benchmarking results on v5p-256, global
batch size 1024:

- Pre-C++11 ABI MFU: 33%
- C++11 ABI MFU: 39%

### GPU Plugin
PyTorch/XLA now provides GPU support through a plugin package similar to `libtpu`:
```
pip install torch~=2.5.0 torch_xla~=2.5.0 https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla_cuda_plugin-2.5.0-py3-none-any.whl
```

The newest stable version for which a PyTorch/XLA:GPU wheel is available is `torch_xla`
2.5. We do not offer a PyTorch/XLA:GPU wheel in the PyTorch/XLA 2.6 release. We
understand this is important and plan to [reinstate GPU support](https://github.com/pytorch/xla/issues/8577) by the 2.7 release.

PyTorch/XLA remains an open-source project and we welcome contributions from the
community to help maintain and improve the project. To contribute, please start
with the [contributors guide](https://github.com/pytorch/xla/blob/master/CONTRIBUTING.md).

## GitHub Doc Map

Our GitHub repository contains many useful docs on working with different aspects of PyTorch/XLA. Here is a list of docs spread around the repository:
- [docs/source/learn](https://github.com/pytorch/xla/tree/master/docs/source/learn): docs for learning concepts associated with XLA, troubleshooting, pjrt, eager mode, and dynamic shape.
- [docs/source/accelerators](https://github.com/pytorch/xla/tree/master/docs/source/accelerators): references to `GPU` and `TPU` accelerator documents.
- [docs/source/perf](https://github.com/pytorch/xla/tree/master/docs/source/perf): documentation about performance-specific aspects of PyTorch/XLA such as `AMP`, `DDP`, `Dynamo`, Fori loop, `FSDP`, quantization, recompilation, and `SPMD`.
- [docs/source/features](https://github.com/pytorch/xla/tree/master/docs/source/features): documentation on distributed torch, Pallas, scan, StableHLO, and Triton.
- [docs/source/contribute](https://github.com/pytorch/xla/tree/master/docs/source/contribute): documents on setting up PyTorch for development, and guides for lowering operations.
- PJRT plugins:
  - [CPU](https://github.com/pytorch/xla/blob/master/plugins/cpu/README.md)
  - [CUDA](https://github.com/pytorch/xla/blob/master/plugins/cuda/README.md)
- [torchax/docs](https://github.com/pytorch/xla/tree/master/torchax/docs): torchax documents
- [torchax/examples](https://github.com/pytorch/xla/tree/master/torchax/examples): torchax examples

## Getting Started
The guides below cover two modes:

- Single process: one Python interpreter controlling a single GPU/TPU at a time
- Multi process: N Python interpreters are launched, corresponding to N GPU/TPUs
  found on the system

Another mode is SPMD, where one Python interpreter controls all N GPU/TPUs found on
the system. Multi processing is more complex, and is not compatible with SPMD. This
tutorial does not dive into SPMD. For more on that, check our
[SPMD guide](https://github.com/pytorch/xla/blob/master/docs/source/perf/spmd_basic.md).

### Simple single process

To update your existing training loop, make the following changes:
```diff
+import torch_xla

 def train(model, training_data, ...):
   ...
   for inputs, labels in train_loader:
+    with torch_xla.step():
       inputs, labels = training_data[i]
+      inputs, labels = inputs.to('xla'), labels.to('xla')
       optimizer.zero_grad()
       outputs = model(inputs)
       loss = loss_fn(outputs, labels)
       loss.backward()
       optimizer.step()

+  torch_xla.sync()
   ...

 if __name__ == '__main__':
   ...
+  # Move the model parameters to your XLA device
+  model.to('xla')
   train(model, training_data, ...)
   ...
```

The changes above should get your model to train on the TPU.
### Multi processing
To update your existing training loop, make the following changes:
```diff
-import torch.multiprocessing as mp
+import torch_xla
+import torch_xla.core.xla_model as xm

 def _mp_fn(index):
   ...

+  # Move the model parameters to your XLA device
+  model.to(torch_xla.device())

   for inputs, labels in train_loader:
+    with torch_xla.step():
+      # Transfer data to the XLA device. This happens asynchronously.
+      inputs, labels = inputs.to(torch_xla.device()), labels.to(torch_xla.device())
       optimizer.zero_grad()
       outputs = model(inputs)
       loss = loss_fn(outputs, labels)
       loss.backward()
-      optimizer.step()
+      # `xm.optimizer_step` combines gradients across replicas
+      xm.optimizer_step(optimizer)

 if __name__ == '__main__':
-  mp.spawn(_mp_fn, args=(), nprocs=world_size)
+  # torch_xla.launch automatically selects the correct world size
+  torch_xla.launch(_mp_fn, args=())
```

If you're using `DistributedDataParallel`, make the following changes:
```diff
 import torch.distributed as dist
-import torch.multiprocessing as mp
+import torch_xla as xla
+import torch_xla.distributed.xla_backend

 def _mp_fn(rank):
   ...

-  os.environ['MASTER_ADDR'] = 'localhost'
-  os.environ['MASTER_PORT'] = '12355'
-  dist.init_process_group("gloo", rank=rank, world_size=world_size)
+  # Rank and world size are inferred from the XLA device runtime
+  dist.init_process_group("xla", init_method='xla://')
+
+  model.to(xla.device())
+  ddp_model = DDP(model, gradient_as_bucket_view=True)

-  model = model.to(rank)
-  ddp_model = DDP(model, device_ids=[rank])

   for inputs, labels in train_loader:
+    with xla.step():
+      inputs, labels = inputs.to(xla.device()), labels.to(xla.device())
       optimizer.zero_grad()
       outputs = ddp_model(inputs)
       loss = loss_fn(outputs, labels)
       loss.backward()
       optimizer.step()

 if __name__ == '__main__':
-  mp.spawn(_mp_fn, args=(), nprocs=world_size)
+  xla.launch(_mp_fn, args=())
```

Additional information on PyTorch/XLA, including a description of its semantics
and functions, is available at [PyTorch.org](http://pytorch.org/xla/). See the
[API Guide](API_GUIDE.md) for best practices when writing networks that run on
XLA devices (TPU, CUDA, CPU and...).

Our comprehensive user guides are available at:
[Documentation for the latest release](https://pytorch.org/xla)
[Documentation for master branch](https://pytorch.org/xla/master)
## PyTorch/XLA tutorials
* [Cloud TPU VM
quickstart](https://cloud.google.com/tpu/docs/run-calculation-pytorch)
* [Cloud TPU Pod slice
quickstart](https://cloud.google.com/tpu/docs/pytorch-pods)
* [Profiling on TPU
VM](https://cloud.google.com/tpu/docs/pytorch-xla-performance-profiling-tpu-vm)
* [GPU guide](docs/gpu.md)

## Reference implementations

The [AI-Hypercomputer/tpu-recipes](https://github.com/AI-Hypercomputer/tpu-recipes)
repo contains examples for training and serving many LLM and diffusion models.

## Available docker images and wheels
### Python packages
PyTorch/XLA releases starting with version r2.1 are available on PyPI. You
can install the main build with `pip install torch_xla`. To also install the
Cloud TPU plugin corresponding to your installed `torch_xla`, install the optional `tpu` dependencies after installing the main build with:

```
pip install 'torch_xla[tpu]' \
  -f https://storage.googleapis.com/libtpu-wheels/index.html \
  -f https://storage.googleapis.com/libtpu-releases/index.html
```

GPU and nightly builds are available in our public GCS bucket.
| Version | Cloud GPU VM Wheels |
| --- | ----------- |
| 2.5 (CUDA 12.1 + Python 3.9) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.5.0-cp39-cp39-manylinux_2_28_x86_64.whl` |
| 2.5 (CUDA 12.1 + Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.5.0-cp310-cp310-manylinux_2_28_x86_64.whl` |
| 2.5 (CUDA 12.1 + Python 3.11) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.5.0-cp311-cp311-manylinux_2_28_x86_64.whl` |
| 2.5 (CUDA 12.4 + Python 3.9) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.5.0-cp39-cp39-manylinux_2_28_x86_64.whl` |
| 2.5 (CUDA 12.4 + Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.5.0-cp310-cp310-manylinux_2_28_x86_64.whl` |
| 2.5 (CUDA 12.4 + Python 3.11) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.5.0-cp311-cp311-manylinux_2_28_x86_64.whl` |
| nightly (Python 3.9) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev-cp39-cp39-linux_x86_64.whl` |
| nightly (Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev-cp310-cp310-linux_x86_64.whl` |
| nightly (Python 3.11) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev-cp311-cp311-linux_x86_64.whl` |
| nightly (CUDA 12.1 + Python 3.8) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.6.0.dev-cp38-cp38-linux_x86_64.whl` |

#### Use nightly build
You can also add `yyyymmdd` like `torch_xla-2.8.0.devyyyymmdd` (or the latest dev version)
to get the nightly wheel of a specified date. Here is an example:

```
pip3 install torch==2.8.0.dev20250318+cpu --index-url https://download.pytorch.org/whl/nightly/cpu
pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.8.0.dev20250318-cp310-cp310-linux_x86_64.whl
```

The torch wheel version `2.8.0.dev20250318+cpu` can be found at https://download.pytorch.org/whl/nightly/torch/.
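The dated-wheel naming used in the example above can be captured in a small helper. This is only a sketch: `nightly_wheel_url` and its parameters are hypothetical names, not part of PyTorch/XLA, and the URL layout is inferred solely from the example command, so it may not hold for other versions or platforms.

```python
from datetime import date

# Base path copied from the nightly TPU VM wheel URLs listed above.
BASE_URL = "https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm"


def nightly_wheel_url(version: str, build_date: date, py_tag: str) -> str:
    """Build the URL of a dated nightly torch_xla TPU VM wheel.

    Hypothetical helper: the naming scheme is an assumption inferred
    from the example above, not an official API.
    """
    stamp = build_date.strftime("%Y%m%d")  # the yyyymmdd suffix
    filename = f"torch_xla-{version}.dev{stamp}-{py_tag}-{py_tag}-linux_x86_64.whl"
    return f"{BASE_URL}/{filename}"


if __name__ == "__main__":
    # Reproduces the wheel URL from the pip3 example above.
    print(nightly_wheel_url("2.8.0", date(2025, 3, 18), "cp310"))
```

The resulting string can be passed directly to `pip3 install`, alongside a `torch` nightly of the same date.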
#### Older versions
| Version | Cloud TPU VMs Wheel |
|---------|-------------------|
| 2.6 (Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0-cp310-cp310-manylinux_2_28_x86_64.whl` |
| 2.5 (Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.5.0-cp310-cp310-manylinux_2_28_x86_64.whl` |
| 2.4 (Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.4.0-cp310-cp310-manylinux_2_28_x86_64.whl` |
| 2.3 (Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.3.0-cp310-cp310-manylinux_2_28_x86_64.whl` |
| 2.2 (Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.2.0-cp310-cp310-manylinux_2_28_x86_64.whl` |
| 2.1 (XRT + Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/xrt/tpuvm/torch_xla-2.1.0%2Bxrt-cp310-cp310-manylinux_2_28_x86_64.whl` |
| 2.1 (Python 3.8) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.1.0-cp38-cp38-linux_x86_64.whl` |

| Version | GPU Wheel |
| --- | ----------- |
| 2.5 (CUDA 12.1 + Python 3.9) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.5.0-cp39-cp39-manylinux_2_28_x86_64.whl` |
| 2.5 (CUDA 12.1 + Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.5.0-cp310-cp310-manylinux_2_28_x86_64.whl` |
| 2.5 (CUDA 12.1 + Python 3.11) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.5.0-cp311-cp311-manylinux_2_28_x86_64.whl` |
| 2.5 (CUDA 12.4 + Python 3.9) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.5.0-cp39-cp39-manylinux_2_28_x86_64.whl` |
| 2.5 (CUDA 12.4 + Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.5.0-cp310-cp310-manylinux_2_28_x86_64.whl` |
| 2.5 (CUDA 12.4 + Python 3.11) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.5.0-cp311-cp311-manylinux_2_28_x86_64.whl` |
| 2.4 (CUDA 12.1 + Python 3.9) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.4.0-cp39-cp39-manylinux_2_28_x86_64.whl` |
| 2.4 (CUDA 12.1 + Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.4.0-cp310-cp310-manylinux_2_28_x86_64.whl` |
| 2.4 (CUDA 12.1 + Python 3.11) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.4.0-cp311-cp311-manylinux_2_28_x86_64.whl` |
| 2.3 (CUDA 12.1 + Python 3.8) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.3.0-cp38-cp38-manylinux_2_28_x86_64.whl` |
| 2.3 (CUDA 12.1 + Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.3.0-cp310-cp310-manylinux_2_28_x86_64.whl` |
| 2.3 (CUDA 12.1 + Python 3.11) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.3.0-cp311-cp311-manylinux_2_28_x86_64.whl` |
| 2.2 (CUDA 12.1 + Python 3.8) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.2.0-cp38-cp38-manylinux_2_28_x86_64.whl` |
| 2.2 (CUDA 12.1 + Python 3.10) | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.2.0-cp310-cp310-manylinux_2_28_x86_64.whl` |
| 2.1 + CUDA 11.8 | `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/11.8/torch_xla-2.1.0-cp38-cp38-manylinux_2_28_x86_64.whl` |
| nightly + CUDA 12.0 >= 2023/06/27| `https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.0/torch_xla-nightly-cp38-cp38-linux_x86_64.whl` |

### Docker
| Version | Cloud TPU VMs Docker |
| --- | ----------- |
| 2.6 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.6.0_3.10_tpuvm` |
| 2.6 (C++11 ABI) | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.6.0_3.10_tpuvm_cxx11` |
| 2.5 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.0_3.10_tpuvm` |
| 2.4 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.4.0_3.10_tpuvm` |
| 2.3 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.3.0_3.10_tpuvm` |
| 2.2 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.2.0_3.10_tpuvm` |
| 2.1 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.1.0_3.10_tpuvm` |
| nightly python | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm` |
| nightly python (C++11 ABI) | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm_cxx11` |

To use the above docker images, pass `--privileged --net host --shm-size=16G` to `docker run`. Here is an example:
```bash
docker run --privileged --net host --shm-size=16G -it us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm /bin/bash
```
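When launching these images from a script, the same command can be assembled as an argument list. The following is a minimal sketch using only the Python standard library; `docker_run_argv` is a hypothetical helper name, and the flags are the ones listed above rather than anything discovered from Docker itself.

```python
import shlex

# Flags the text above asks you to pass when running the TPU VM images.
REQUIRED_FLAGS = ["--privileged", "--net", "host", "--shm-size=16G"]


def docker_run_argv(image: str, command: str = "/bin/bash") -> list[str]:
    """Return an argv list for an interactive `docker run` of `image`.

    Hypothetical helper for illustration; it only builds the command
    and does not start a container.
    """
    return ["docker", "run", *REQUIRED_FLAGS, "-it", image, command]


if __name__ == "__main__":
    image = "us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm"
    # Print the shell-quoted command for inspection.
    print(shlex.join(docker_run_argv(image)))
```

Building a list rather than a single string means the result can be handed to `subprocess.run(argv, check=True)` without extra shell quoting.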
| Version | GPU CUDA 12.4 Docker |
| --- | ----------- |
| 2.5 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.0_3.10_cuda_12.4` |
| 2.4 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.4.0_3.10_cuda_12.4` |

| Version | GPU CUDA 12.1 Docker |
| --- | ----------- |
| 2.5 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.0_3.10_cuda_12.1` |
| 2.4 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.4.0_3.10_cuda_12.1` |
| 2.3 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.3.0_3.10_cuda_12.1` |
| 2.2 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.2.0_3.10_cuda_12.1` |
| 2.1 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.1.0_3.10_cuda_12.1` |
| nightly | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_12.1` |
| nightly at date | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_12.1_YYYYMMDD` |

| Version | GPU CUDA 11.8 Docker |
| --- | ----------- |
| 2.1 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.1.0_3.10_cuda_11.8` |
| 2.0 | `us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.0_3.8_cuda_11.8` |

To run the GPU images, use [compute instances with
GPUs](https://cloud.google.com/compute/docs/gpus/create-vm-with-gpus).

## Troubleshooting
If PyTorch/XLA isn't performing as expected, see the [troubleshooting
guide](docs/source/learn/troubleshoot.md), which has suggestions for debugging and optimizing
your network(s).

## Providing Feedback

The PyTorch/XLA team is always happy to hear from users and OSS contributors!
The best way to reach out is by filing an issue on this GitHub repository. Questions, bug
reports, feature requests, build issues, etc. are all welcome!

## Contributing
See the [contribution guide](CONTRIBUTING.md).
## Disclaimer
This repository is jointly operated and maintained by Google, Meta and a
number of individual contributors listed in the
[CONTRIBUTORS](https://github.com/pytorch/xla/graphs/contributors) file. For
questions directed at Meta, please send an email to [email protected]. For
questions directed at Google, please send an email to
[email protected]. For all other questions, please open up an issue
in this repository [here](https://github.com/pytorch/xla/issues).

## Additional Reads

You can find additional useful reading materials in:
* [Performance debugging on Cloud TPU
VM](https://cloud.google.com/blog/topics/developers-practitioners/pytorchxla-performance-debugging-tpu-vm-part-1)
* [Lazy tensor
intro](https://pytorch.org/blog/understanding-lazytensor-system-performance-with-pytorch-xla-on-cloud-tpu/)
* [Scaling deep learning workloads with PyTorch / XLA and Cloud TPU
VM](https://cloud.google.com/blog/topics/developers-practitioners/scaling-deep-learning-workloads-pytorch-xla-and-cloud-tpu-vm)
* [Scaling PyTorch models on Cloud TPUs with
  FSDP](https://pytorch.org/blog/scaling-pytorch-models-on-cloud-tpus-with-fsdp/)

## Related Projects
* [OpenXLA](https://github.com/openxla)
* [HuggingFace](https://huggingface.co/docs/accelerate/en/basic_tutorials/tpu)
* [JetStream](https://github.com/google/JetStream-pytorch)