The pjrt-plugin implementation for VeriSilicon NPU IP for TensorFlow/PyTorch/other ecosystems.
- Host: GitHub
- URL: https://github.com/verisilicon/vsi-pjrt-plugin
- Owner: VeriSilicon
- License: mit
- Created: 2024-02-18T06:27:43.000Z (11 months ago)
- Default Branch: main
- Last Pushed: 2024-04-28T09:16:54.000Z (9 months ago)
- Last Synced: 2024-04-28T10:27:04.551Z (9 months ago)
- Language: C++
- Size: 146 KB
- Stars: 3
- Watchers: 2
- Forks: 1
- Open Issues: 1
Metadata Files:
- Readme: README.md
- License: LICENSE
# VSI PJRT Plugin
This is a PJRT client implementation for the VeriSilicon NPU/GPU platform as a dynamic plugin.
## Building
### Install PyTorch
```shell
pip3 install torch~=2.3.0 torchvision~=0.18.0 --index-url https://download.pytorch.org/whl/cpu
```
### Install PyTorch/XLA
```shell
pip3 install torch_xla~=2.3.0
```
### Build PJRT plugin
You can build the plugin dynamic library using Bazel:
```shell
bazel build //xla:vsi_pjrt_plugin
```
The built plugin library is located at `bazel-bin/xla/pjrt/c/pjrt_c_api_vsi_plugin.so`. For development purposes, you can create a symlink to the built plugin library:
```shell
cd torch_xla_vsi_plugin/lib
ln -s ../../bazel-bin/xla/pjrt/c/pjrt_c_api_vsi_plugin.so pjrt_c_api_vsi_plugin.so
```
And add the `torch_xla_vsi_plugin` dir to `PYTHONPATH`.
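A relative symlink like the one above only works while its target exists, so it can be handy to check the link before importing the plugin. The following is a minimal sketch, not part of the plugin itself: `plugin_library_status` is a hypothetical helper name, and the only detail taken from this README is the library file name.

```python
from pathlib import Path

# File name of the built plugin library, as given in this README.
PLUGIN_NAME = "pjrt_c_api_vsi_plugin.so"


def plugin_library_status(lib_dir, name=PLUGIN_NAME):
    """Classify the plugin library in lib_dir as 'missing', 'dangling', or 'ok'.

    Hypothetical helper for illustration only; it is not shipped with the plugin.
    """
    path = Path(lib_dir) / name
    if path.is_symlink() and not path.exists():
        # The link itself is present, but its target is not
        # (for example, the plugin has not been built yet).
        return "dangling"
    if not path.exists():
        # Neither a regular file nor a symlink with this name exists.
        return "missing"
    return "ok"
```

For example, `plugin_library_status("torch_xla_vsi_plugin/lib")` run from the repository root would report whether the development symlink currently resolves to a built library.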
Alternatively, you can build and bundle the plugin as a pip wheel:
```shell
# Build wheel.
pip wheel . -v
# Or install directly.
pip install . -v
```
### Generate compilation database
This command generates a `compile_commands.json` in the current workspace for Clang-based linter tools.
```shell
bazel run :refresh_compile_commands
```
## Usage
### Set environment variables
```shell
# Locate the vsi unified driver.
VIVANTE_SDK_DIR=${VIV_SDK_INSTALL_PATH}
LD_LIBRARY_PATH=${VIVANTE_SDK_DIR}/[lib|lib64|drivers]
# Need to specify hardware PID if using simulator driver.
VSIMULATOR_CONFIG=VIP9000ULSI_PID0XBA
# Map PyTorch Long type to HLO s32 type.
XLA_USE_32BIT_LONG=1
# Since there's no StableHLO -> HLO conversion of Q/DQ ops,
# need to disable HLO -> StableHLO -> HLO roundtrip.
XLA_STABLEHLO_COMPILE=0
```
### Load plugin dynamically
```python
from torch_xla.experimental import plugins
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr

from torch_xla_vsi_plugin import VsiPlugin
# Use dynamic PJRT plugin.
vsi_plugin = VsiPlugin()
plugins.use_dynamic_plugins()
plugins.register_plugin("vsi", vsi_plugin)
xr.set_device_type("vsi")

# Now you can use the npu device for PyTorch modules and tensors.
xla_device = xm.xla_device()
```
See more examples in `examples/`.
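
The variables listed under "Set environment variables" need to be visible to the process that loads the plugin. The dynamic loader typically reads `LD_LIBRARY_PATH` when a process starts, so one option is to assemble the environment in Python and pass it to a child process. The sketch below assumes that approach. `build_vsi_env` and `DRIVER_SUBDIRS` are hypothetical names, not part of the plugin. Picking the first existing directory among `lib`, `lib64`, and `drivers`, and prepending to any existing `LD_LIBRARY_PATH` rather than replacing it, are assumptions made for illustration.

```python
import os
from pathlib import Path

# Candidate driver subdirectories, taken from the `[lib|lib64|drivers]`
# placeholder in the environment variable listing.
DRIVER_SUBDIRS = ("lib", "lib64", "drivers")


def build_vsi_env(sdk_dir, simulator_config=None, base_env=None):
    """Return a copy of base_env with the VSI-related variables set.

    Hypothetical helper for illustration only; it is not shipped with the plugin.
    """
    env = dict(os.environ if base_env is None else base_env)
    sdk = Path(sdk_dir)

    # Assumption: use the first candidate subdirectory that exists,
    # falling back to "lib" if none is found.
    subdir = next((d for d in DRIVER_SUBDIRS if (sdk / d).is_dir()), DRIVER_SUBDIRS[0])
    driver_dir = str(sdk / subdir)

    env["VIVANTE_SDK_DIR"] = str(sdk)

    # Assumption: prepend the driver directory and keep existing entries.
    existing = env.get("LD_LIBRARY_PATH", "")
    env["LD_LIBRARY_PATH"] = driver_dir + (os.pathsep + existing if existing else "")

    # Only needed when using the simulator driver.
    if simulator_config is not None:
        env["VSIMULATOR_CONFIG"] = simulator_config

    # Map PyTorch Long type to HLO s32 type.
    env["XLA_USE_32BIT_LONG"] = "1"
    # Disable the HLO -> StableHLO -> HLO roundtrip.
    env["XLA_STABLEHLO_COMPILE"] = "0"
    return env
```

The returned dictionary can be passed as the `env` argument of `subprocess.run` when launching a script that loads the plugin, for example `subprocess.run([sys.executable, "my_script.py"], env=build_vsi_env(sdk_path))`, where `my_script.py` and `sdk_path` are placeholders.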