PyTorch/XLA: Enabling PyTorch on Google TPU
PyTorch/XLA is a Python package that uses the XLA deep learning compiler to connect the PyTorch deep learning framework and Cloud TPUs. You can try it right now, for free, on a single Cloud TPU VM with Kaggle!
Take a look at one of our Kaggle notebooks to get started:
- Stable Diffusion with PyTorch/XLA 2.0
- Distributed PyTorch/XLA Basics
Installation
TPU
To install the PyTorch/XLA stable build in a new TPU VM:
```sh
pip install torch~=2.5.0 'torch_xla[tpu]~=2.5.0' -f https://storage.googleapis.com/libtpu-releases/index.html
```
To install the PyTorch/XLA nightly build in a new TPU VM:
```sh
pip3 install --pre torch torchvision --index-url https://download.pytorch.org/whl/nightly/cpu
pip install 'torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev-cp310-cp310-linux_x86_64.whl' -f https://storage.googleapis.com/libtpu-releases/index.html
```
GPU Plugin
PyTorch/XLA now provides GPU support through a plugin package similar to libtpu:
```sh
pip install torch~=2.5.0 torch_xla~=2.5.0 https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla_cuda_plugin-2.5.0-py3-none-any.whl
```
Getting Started
To update your existing training loop, make the following changes:
```diff
-import torch.multiprocessing as mp
+import torch_xla as xla
+import torch_xla.core.xla_model as xm

 def _mp_fn(index):
   ...

+  # Move the model parameters to your XLA device
+  model.to(xla.device())

   for inputs, labels in train_loader:
+    with xla.step():
+      # Transfer data to the XLA device. This happens asynchronously.
+      inputs, labels = inputs.to(xla.device()), labels.to(xla.device())
       optimizer.zero_grad()
       outputs = model(inputs)
       loss = loss_fn(outputs, labels)
       loss.backward()
-      optimizer.step()
+      # `xm.optimizer_step` combines gradients across replicas
+      xm.optimizer_step(optimizer)

 if __name__ == '__main__':
-  mp.spawn(_mp_fn, args=(), nprocs=world_size)
+  # xla.launch automatically selects the correct world size
+  xla.launch(_mp_fn, args=())
```
If you're using DistributedDataParallel, make the following changes:
```diff
 import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel as DDP
-import torch.multiprocessing as mp
+import torch_xla as xla
+import torch_xla.distributed.xla_backend

 def _mp_fn(rank):
   ...

-  os.environ['MASTER_ADDR'] = 'localhost'
-  os.environ['MASTER_PORT'] = '12355'
-  dist.init_process_group("gloo", rank=rank, world_size=world_size)
+  # Rank and world size are inferred from the XLA device runtime
+  dist.init_process_group("xla", init_method='xla://')
+
+  model.to(xla.device())
+  # `gradient_as_bucket_view=True` required for XLA
+  ddp_model = DDP(model, gradient_as_bucket_view=True)

-  model = model.to(rank)
-  ddp_model = DDP(model, device_ids=[rank])

   for inputs, labels in train_loader:
+    with xla.step():
+      inputs, labels = inputs.to(xla.device()), labels.to(xla.device())
       optimizer.zero_grad()
       outputs = ddp_model(inputs)
       loss = loss_fn(outputs, labels)
       loss.backward()
       optimizer.step()

 if __name__ == '__main__':
-  mp.spawn(_mp_fn, args=(), nprocs=world_size)
+  xla.launch(_mp_fn, args=())
```
Additional information on PyTorch/XLA, including a description of its semantics and functions, is available at PyTorch.org. See the API Guide for best practices when writing networks that run on XLA devices (TPU, CUDA, CPU and...).
Our comprehensive user guides are available at:
- Documentation for the latest release
- Documentation for master branch
PyTorch/XLA tutorials
- Cloud TPU VM quickstart
- Cloud TPU Pod slice quickstart
- Profiling on TPU VM
- GPU guide
Available Docker images and wheels
Python packages
Starting with r2.1, PyTorch/XLA releases are available on PyPI. You can install the main build with pip install torch_xla. To also install the Cloud TPU plugin that corresponds to your installed torch_xla, install the optional tpu dependencies after installing the main build:
```sh
pip install 'torch_xla[tpu]' -f https://storage.googleapis.com/libtpu-releases/index.html
```
GPU and nightly builds are available in our public GCS bucket.
Version | Cloud TPU/GPU VM Wheels |
---|---|
2.5 (CUDA 12.1 + Python 3.9) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.5.0-cp39-cp39-manylinux_2_28_x86_64.whl |
2.5 (CUDA 12.1 + Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.5.0-cp310-cp310-manylinux_2_28_x86_64.whl |
2.5 (CUDA 12.1 + Python 3.11) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.5.0-cp311-cp311-manylinux_2_28_x86_64.whl |
2.5 (CUDA 12.4 + Python 3.9) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.5.0-cp39-cp39-manylinux_2_28_x86_64.whl |
2.5 (CUDA 12.4 + Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.5.0-cp310-cp310-manylinux_2_28_x86_64.whl |
2.5 (CUDA 12.4 + Python 3.11) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.5.0-cp311-cp311-manylinux_2_28_x86_64.whl |
nightly (Python 3.8) | https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev-cp38-cp38-linux_x86_64.whl |
nightly (Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev-cp310-cp310-linux_x86_64.whl |
nightly (CUDA 12.1 + Python 3.8) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.6.0.dev-cp38-cp38-linux_x86_64.whl |
The torch wheel version 2.6.0.dev20240925+cpu can be found at https://download.pytorch.org/whl/nightly/torch/.
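
The install commands pin torch and torch_xla to the same release line (torch~=2.5.0 with torch_xla~=2.5.0), and the note above pairs a torch 2.6.0.dev nightly with the torch_xla 2.6.0.dev wheel. The sketch below, which uses only the Python standard library, checks an installed pair against that convention. The rule it encodes is an assumption inferred from these examples (same major.minor, plus the same yyyymmdd date when both versions carry one), not an official compatibility matrix, and parse_version, versions_match and installed_version are hypothetical helper names.

```python
import re
from importlib import metadata
from typing import Optional, Tuple

# Matches e.g. "2.5.0", "2.6.0.dev", "2.6.0.dev20240925+cpu" (assumed version shapes).
_VERSION_RE = re.compile(r"^(\d+)\.(\d+)(?:\.\d+)?(?:\.dev(\d{8})?)?(?:\+[\w.]+)?$")


def parse_version(version: str) -> Tuple[int, int, Optional[str]]:
    """Return (major, minor, nightly_date_or_None) for a torch/torch_xla version string."""
    match = _VERSION_RE.match(version.strip())
    if match is None:
        raise ValueError(f"unrecognized version string: {version!r}")
    major, minor, date = match.groups()
    return int(major), int(minor), date


def versions_match(torch_version: str, xla_version: str) -> bool:
    """Heuristic check: same major.minor, and the same nightly date when both have one."""
    t_major, t_minor, t_date = parse_version(torch_version)
    x_major, x_minor, x_date = parse_version(xla_version)
    if (t_major, t_minor) != (x_major, x_minor):
        return False
    if t_date is not None and x_date is not None:
        return t_date == x_date
    return True


def installed_version(package: str) -> Optional[str]:
    """Installed distribution version, or None when the package is not installed."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    torch_v, xla_v = installed_version("torch"), installed_version("torch_xla")
    if torch_v is None or xla_v is None:
        print("torch and/or torch_xla is not installed")
    else:
        status = "match" if versions_match(torch_v, xla_v) else "MISMATCH"
        print(f"torch {torch_v} / torch_xla {xla_v}: {status}")
```

Treat a mismatch as a hint to reinstall the pair together, not as proof that the combination cannot work.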
Using a dated nightly build (after 08/20/2024)
You can also add yyyymmdd after torch_xla-2.6.0.dev to get the nightly wheel for a specific date. The example below pins 08/20/2024; note that the wheels for that date carry the 2.5.0.dev prefix:
```sh
pip3 install torch==2.5.0.dev20240820+cpu --index-url https://download.pytorch.org/whl/nightly/cpu
pip3 install https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.5.0.dev20240820-cp310-cp310-linux_x86_64.whl
```
Dated torch nightly wheels are listed at https://download.pytorch.org/whl/nightly/torch/.
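
The dated-nightly naming above can be captured in a small helper. This is a sketch that assumes the pattern from the example holds for other dates: a release prefix, then .dev, then yyyymmdd, a cp310 TPU VM wheel, and a torch pin of the form torch==<prefix>.dev<yyyymmdd>+cpu. dated_nightly is a hypothetical name, and you still have to choose the release prefix that was current on the chosen date.

```python
import datetime
from typing import Tuple

WHEEL_BASE = "https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm"
# The example above uses 2024-08-20 itself, so this bound is treated as inclusive.
EARLIEST_DATED_NIGHTLY = datetime.date(2024, 8, 20)


def dated_nightly(prefix: str, day: datetime.date, python: str = "310") -> Tuple[str, str]:
    """Return (torch pip requirement, torch_xla wheel URL) for a dated TPU VM nightly.

    prefix: release prefix in effect on that day, e.g. "2.5.0" for 2024-08-20 (assumed input).
    python: CPython tag digits, e.g. "310" for Python 3.10.
    """
    if day < EARLIEST_DATED_NIGHTLY:
        raise ValueError(f"dated nightlies are only described from {EARLIEST_DATED_NIGHTLY}")
    stamp = day.strftime("%Y%m%d")  # the yyyymmdd suffix that follows ".dev"
    torch_pin = f"torch=={prefix}.dev{stamp}+cpu"
    wheel_url = (f"{WHEEL_BASE}/torch_xla-{prefix}.dev{stamp}"
                 f"-cp{python}-cp{python}-linux_x86_64.whl")
    return torch_pin, wheel_url


if __name__ == "__main__":
    pin, url = dated_nightly("2.5.0", datetime.date(2024, 8, 20))
    print(f"pip3 install {pin} --index-url https://download.pytorch.org/whl/nightly/cpu")
    print(f"pip3 install {url}")
```

For 2024-08-20 with prefix 2.5.0, the two printed commands reproduce the example above.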
Older versions
Version | Cloud TPU VMs Wheel |
---|---|
2.4 (Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.4.0-cp310-cp310-manylinux_2_28_x86_64.whl |
2.3 (Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.3.0-cp310-cp310-manylinux_2_28_x86_64.whl |
2.2 (Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.2.0-cp310-cp310-manylinux_2_28_x86_64.whl |
2.1 (XRT + Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/xrt/tpuvm/torch_xla-2.1.0%2Bxrt-cp310-cp310-manylinux_2_28_x86_64.whl |
2.1 (Python 3.8) | https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.1.0-cp38-cp38-linux_x86_64.whl |
Version | GPU Wheel |
---|---|
2.5 (CUDA 12.1 + Python 3.9) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.5.0-cp39-cp39-manylinux_2_28_x86_64.whl |
2.5 (CUDA 12.1 + Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.5.0-cp310-cp310-manylinux_2_28_x86_64.whl |
2.5 (CUDA 12.1 + Python 3.11) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.5.0-cp311-cp311-manylinux_2_28_x86_64.whl |
2.5 (CUDA 12.4 + Python 3.9) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.5.0-cp39-cp39-manylinux_2_28_x86_64.whl |
2.5 (CUDA 12.4 + Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.5.0-cp310-cp310-manylinux_2_28_x86_64.whl |
2.5 (CUDA 12.4 + Python 3.11) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.4/torch_xla-2.5.0-cp311-cp311-manylinux_2_28_x86_64.whl |
2.4 (CUDA 12.1 + Python 3.9) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.4.0-cp39-cp39-manylinux_2_28_x86_64.whl |
2.4 (CUDA 12.1 + Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.4.0-cp310-cp310-manylinux_2_28_x86_64.whl |
2.4 (CUDA 12.1 + Python 3.11) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.4.0-cp311-cp311-manylinux_2_28_x86_64.whl |
2.3 (CUDA 12.1 + Python 3.8) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.3.0-cp38-cp38-manylinux_2_28_x86_64.whl |
2.3 (CUDA 12.1 + Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.3.0-cp310-cp310-manylinux_2_28_x86_64.whl |
2.3 (CUDA 12.1 + Python 3.11) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.3.0-cp311-cp311-manylinux_2_28_x86_64.whl |
2.2 (CUDA 12.1 + Python 3.8) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.2.0-cp38-cp38-manylinux_2_28_x86_64.whl |
2.2 (CUDA 12.1 + Python 3.10) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.1/torch_xla-2.2.0-cp310-cp310-manylinux_2_28_x86_64.whl |
2.1 (CUDA 11.8 + Python 3.8) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/11.8/torch_xla-2.1.0-cp38-cp38-manylinux_2_28_x86_64.whl |
nightly (CUDA 12.0 + Python 3.8, builds from 2023/06/27 on) | https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda/12.0/torch_xla-nightly-cp38-cp38-linux_x86_64.whl |
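
Most GPU wheel URLs in the tables above follow one pattern: .../wheels/cuda/<CUDA>/torch_xla-<release>-cpXY-cpXY-manylinux_2_28_x86_64.whl. The helper below is a sketch that rebuilds a URL from that pattern for the running interpreter and refuses combinations the tables do not list. The SUPPORTED mapping is transcribed by hand from the 2.2 to 2.5 rows; the 2.1 and nightly rows use different file names and are deliberately left out. cuda_wheel_url is a hypothetical helper, not part of torch_xla.

```python
import sys
from typing import Optional, Tuple

BASE = "https://storage.googleapis.com/pytorch-xla-releases/wheels/cuda"

# (release, CUDA) -> CPython versions that have a wheel, copied from the tables above.
SUPPORTED = {
    ("2.5.0", "12.1"): {(3, 9), (3, 10), (3, 11)},
    ("2.5.0", "12.4"): {(3, 9), (3, 10), (3, 11)},
    ("2.4.0", "12.1"): {(3, 9), (3, 10), (3, 11)},
    ("2.3.0", "12.1"): {(3, 8), (3, 10), (3, 11)},
    ("2.2.0", "12.1"): {(3, 8), (3, 10)},
}


def cuda_wheel_url(release: str, cuda: str, python: Optional[Tuple[int, int]] = None) -> str:
    """Build the GPU wheel URL for a listed release/CUDA/CPython combination."""
    # Default to the interpreter that is running this code.
    py = python if python is not None else (sys.version_info.major, sys.version_info.minor)
    if py not in SUPPORTED.get((release, cuda), set()):
        raise ValueError(
            f"no wheel listed for torch_xla {release}, CUDA {cuda}, Python {py[0]}.{py[1]}")
    tag = f"cp{py[0]}{py[1]}"
    return f"{BASE}/{cuda}/torch_xla-{release}-{tag}-{tag}-manylinux_2_28_x86_64.whl"
```

For example, on Python 3.10 cuda_wheel_url("2.5.0", "12.4") should yield the 2.5 (CUDA 12.4 + Python 3.10) URL from the table, ready to pass to pip install. Keeping an explicit allow-list avoids generating URLs for wheels that were never published.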
Docker
Version | Cloud TPU VMs Docker |
---|---|
2.5 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.0_3.10_tpuvm |
2.4 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.4.0_3.10_tpuvm |
2.3 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.3.0_3.10_tpuvm |
2.2 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.2.0_3.10_tpuvm |
2.1 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.1.0_3.10_tpuvm |
nightly (Python 3.10) | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm |
To use the TPU images above, pass --privileged --net host --shm-size=16G to docker run. For example:
```sh
docker run --privileged --net host --shm-size=16G -it us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.10_tpuvm /bin/bash
```
Version | GPU CUDA 12.4 Docker |
---|---|
2.5 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.0_3.10_cuda_12.4 |
2.4 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.4.0_3.10_cuda_12.4 |
Version | GPU CUDA 12.1 Docker |
---|---|
2.5 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.5.0_3.10_cuda_12.1 |
2.4 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.4.0_3.10_cuda_12.1 |
2.3 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.3.0_3.10_cuda_12.1 |
2.2 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.2.0_3.10_cuda_12.1 |
2.1 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.1.0_3.10_cuda_12.1 |
nightly | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_12.1 |
nightly at date | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:nightly_3.8_cuda_12.1_YYYYMMDD |
Version | GPU CUDA 11.8 Docker |
---|---|
2.1 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.1.0_3.10_cuda_11.8 |
2.0 | us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla:r2.0_3.8_cuda_11.8 |
For running on compute instances with GPUs, see the GPU guide.
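
The Docker tags above follow a regular naming scheme, and the TPU images need the --privileged --net host --shm-size=16G flags. The sketch below assembles an image reference and a matching docker run argument list. The tag rules are inferred from the tables only (r<release>.0_3.10_<suffix> for 2.1 and later, nightly_3.10_tpuvm, and nightly_3.8_cuda_12.1 with an optional _YYYYMMDD date), and docker_image and docker_run_args are hypothetical names. The run flags are the ones given for the TPU images, so GPU containers may need different options.

```python
from typing import List, Optional

REGISTRY = "us-central1-docker.pkg.dev/tpu-pytorch-releases/docker/xla"


def docker_image(release: Optional[str], accelerator: str,
                 cuda: Optional[str] = None, date: Optional[str] = None) -> str:
    """Return an image reference following the tag patterns in the tables above.

    release: "2.5" for the r2.5.0 image, or None for a nightly image.
    accelerator: "tpu", or "cuda" together with a CUDA version such as "12.4".
    date: optional YYYYMMDD suffix; only listed for the CUDA 12.1 nightly.
    """
    if accelerator == "tpu":
        suffix = "tpuvm"
    elif accelerator == "cuda" and cuda:
        suffix = f"cuda_{cuda}"
    else:
        raise ValueError("use accelerator='tpu', or accelerator='cuda' with a CUDA version")

    if release is None:
        if accelerator == "tpu":
            if date:
                raise ValueError("dated nightly tags are only listed for CUDA 12.1")
            return f"{REGISTRY}:nightly_3.10_{suffix}"
        if cuda != "12.1":
            raise ValueError("nightly GPU images are only listed for CUDA 12.1")
        tag = f"nightly_3.8_{suffix}"
        return f"{REGISTRY}:{tag}_{date}" if date else f"{REGISTRY}:{tag}"

    major, minor = (int(part) for part in release.split("."))
    if (major, minor) < (2, 1):
        # The 2.0 image uses a different tag (r2.0_3.8_cuda_11.8), so it is not generated here.
        raise ValueError("tag pattern only covers releases 2.1 and later")
    return f"{REGISTRY}:r{release}.0_3.10_{suffix}"


def docker_run_args(image: str) -> List[str]:
    """argv for an interactive shell, using the flags required for the TPU images."""
    return ["docker", "run", "--privileged", "--net", "host", "--shm-size=16G",
            "-it", image, "/bin/bash"]
```

The argument list can be handed to subprocess.run(..., check=True); returning a list instead of a shell string avoids quoting problems.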
Troubleshooting
If PyTorch/XLA isn't performing as expected, see the troubleshooting guide, which has suggestions for debugging and optimizing your network(s).
Providing Feedback
The PyTorch/XLA team is always happy to hear from users and OSS contributors! The best way to reach out is by filing an issue in this GitHub repository. Questions, bug reports, feature requests, build issues, etc. are all welcome!
Contributing
See the contribution guide.
Disclaimer
This repository is jointly operated and maintained by Google, Meta and a number of individual contributors listed in the CONTRIBUTORS file. For questions directed at Meta, please send an email to opensource@fb.com. For questions directed at Google, please send an email to pytorch-xla@googlegroups.com. For all other questions, please open an issue in this repository.
Additional Reads
You can find additional useful reading materials in:
- Performance debugging on Cloud TPU VM
- Lazy tensor intro
- Scaling deep learning workloads with PyTorch / XLA and Cloud TPU VM
- Scaling PyTorch models on Cloud TPUs with FSDP
Related Projects
- OpenXLA
- HuggingFace
- JetStream