https://github.com/tchordia/ray-serve-deepspeed

Run deepspeed on ray serve

# Ray-DeepSpeed-Inference

*EXPERIMENTAL AND NOT PRODUCTION READY! Many rough edges.*

Based on https://github.com/microsoft/DeepSpeedExamples/tree/master/inference/huggingface/text-generation

## How to run

This example runs OPT-66b inference on a cluster of g4dn nodes (in my tests, 3 x g4dn.12xlarge with 4 GPUs each, for a total of 12 GPUs). It can also run on 12 x g4dn.4xlarge (1 GPU each).

```bash
# Download the resharded OPT-66b checkpoint to every node in the cluster
python run_on_every_node.py download_model "s3://large-dl-models-mirror/models--anyscale--opt-66b-resharded/main/" "~/model"

# Run inference with one worker group spanning 12 GPUs
python deepspeed_inference_actors.py --name "facebook/opt-66b" --checkpoint_path "~/model" \
    --batch_size 1 --ds_inference --use_kernel --use_meta_tensor \
    --num_worker_groups 1 --num_gpus_per_worker_group 12
```
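
The cluster sizes above follow from simple arithmetic: the job needs `num_worker_groups * num_gpus_per_worker_group` GPUs. The sketch below checks that for the two layouts mentioned. The per-instance GPU counts (4 NVIDIA T4 GPUs on g4dn.12xlarge, 1 on g4dn.4xlarge) are AWS specifications; `required_gpus`, `cluster_gpus`, and `fits` are hypothetical helpers, not part of this repository.

```python
from typing import Dict

# GPUs per instance for the g4dn sizes mentioned above (each GPU is an NVIDIA T4).
GPUS_PER_INSTANCE: Dict[str, int] = {"g4dn.4xlarge": 1, "g4dn.12xlarge": 4}


def required_gpus(num_worker_groups: int, num_gpus_per_worker_group: int) -> int:
    """Every worker group shards one copy of the model over its own GPUs."""
    return num_worker_groups * num_gpus_per_worker_group


def cluster_gpus(nodes: Dict[str, int]) -> int:
    """Total GPUs for a mapping of instance type -> number of nodes."""
    return sum(GPUS_PER_INSTANCE[kind] * count for kind, count in nodes.items())


def fits(nodes: Dict[str, int], num_worker_groups: int, num_gpus_per_worker_group: int) -> bool:
    """True if the cluster has at least as many GPUs as the worker groups need."""
    return cluster_gpus(nodes) >= required_gpus(num_worker_groups, num_gpus_per_worker_group)


if __name__ == "__main__":
    # The two layouts from this README, with --num_worker_groups 1 --num_gpus_per_worker_group 12.
    for layout in ({"g4dn.12xlarge": 3}, {"g4dn.4xlarge": 12}):
        print(layout, cluster_gpus(layout), fits(layout, 1, 12))
```

By the same arithmetic, a second worker group (`--num_worker_groups 2`) would need 24 GPUs.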

## How it works

This repository demonstrates how to use [DeepSpeed Inference](https://www.deepspeed.ai/tutorials/inference-tutorial/) with [Ray](https://ray.io/) for scalable batch inference. Together, these two tools allow efficient text generation with large language models, including models as large as OPT-66b.

DeepSpeed Inference uses automatic model parallelism to distribute the model across multiple GPUs. Ray handles the scheduling and orchestration of the workload.
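
A back-of-the-envelope estimate shows why a model of this size has to be split across many GPUs. This sketch assumes fp16 weights (2 bytes per parameter), a nominal 66 billion parameters for OPT-66b, and an even split; it ignores activations, the KV cache, and CUDA context overhead, so real memory use is higher. `weights_gb_per_gpu` is a hypothetical helper.

```python
# Weights-only estimate of per-GPU memory when a model is split evenly
# across GPUs. Assumes fp16 (2 bytes/parameter); activations, the KV cache
# and CUDA overhead are ignored, so actual usage is higher.

T4_MEMORY_GB = 16  # each g4dn GPU is an NVIDIA T4 with 16 GB of memory


def weights_gb_per_gpu(num_params: float, num_gpus: int, bytes_per_param: int = 2) -> float:
    """Weight memory per GPU in decimal gigabytes (1 GB = 1e9 bytes)."""
    return num_params * bytes_per_param / num_gpus / 1e9


if __name__ == "__main__":
    for gpus in (8, 12):
        per_gpu = weights_gb_per_gpu(66e9, gpus)
        print(f"{gpus} GPUs: {per_gpu:.1f} GB of weights per GPU, fits={per_gpu < T4_MEMORY_GB}")
```

Under these assumptions, 8 GPUs would need 16.5 GB of weights per GPU, which already exceeds a T4, while 12 GPUs need about 11 GB and leave some headroom.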

There are three key parts to the code:
1. `deepspeed_inference_actors.py` (the entrypoint) generates a sample Ray Dataset and uses `ray.train.batch_predictor.BatchPredictor` with a custom `DeepSpeedPredictor`. The `BatchPredictor` spawns `num_worker_groups` `DeepSpeedPredictor` actors, each receiving a share of the data.
2. `deepspeed_predictor.py` contains the code for the `DeepSpeedPredictor`. Each `DeepSpeedPredictor` actor spawns `num_gpus_per_worker_group` worker actors (`PredictionWorker`), connected together via a `torch.distributed` backend, as required by DeepSpeed. Once initialized, the DeepSpeed model is ready for prediction.
3. `deepspeed_utils.py` contains code, based on a DeepSpeed example, that the `PredictionWorker` actors use.
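
The two levels of fan-out described above can be sketched in plain Python without Ray. This is a hypothetical illustration: `shard_rows`, `plan_workers`, and `WorkerSpec` are made-up names, not the repository's code, and in the real pipeline `BatchPredictor` and Ray Datasets decide how the data is split. Data is divided among `num_worker_groups` groups, while inside each group the `num_gpus_per_worker_group` workers get ranks in their own `torch.distributed` world.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class WorkerSpec:
    group: int       # index of the worker group (one DeepSpeedPredictor each)
    rank: int        # rank inside that group's own torch.distributed world
    world_size: int  # equals num_gpus_per_worker_group


def shard_rows(rows: List[str], num_worker_groups: int) -> List[List[str]]:
    """Split the data into num_worker_groups contiguous, near-equal shares."""
    base, extra = divmod(len(rows), num_worker_groups)
    shards: List[List[str]] = []
    start = 0
    for group in range(num_worker_groups):
        end = start + base + (1 if group < extra else 0)
        shards.append(rows[start:end])
        start = end
    return shards


def plan_workers(num_worker_groups: int, num_gpus_per_worker_group: int) -> List[List[WorkerSpec]]:
    """One list per group; ranks restart at 0 because each group is a separate world."""
    return [
        [WorkerSpec(group, rank, num_gpus_per_worker_group) for rank in range(num_gpus_per_worker_group)]
        for group in range(num_worker_groups)
    ]


if __name__ == "__main__":
    prompts = [f"prompt {i}" for i in range(5)]
    print([len(share) for share in shard_rows(prompts, 2)])
    for group in plan_workers(2, 4):
        # In tensor-parallel inference every rank of a group sees the same batch.
        print([(w.group, w.rank, w.world_size) for w in group])
```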

In other words, each `DeepSpeedPredictor` creates a worker group of `PredictionWorker` actors that share a single model. A worker group is inelastic: if one worker fails, the entire group fails. This is similar to how Ray Train works (in fact, the same logic could be implemented with Ray Train private APIs instead of `PredictionWorker`).
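
The inelastic behavior can be illustrated with threads standing in for `PredictionWorker` actors. This is a hypothetical sketch (`run_group` and `WorkerGroupError` are illustrative names): every worker receives the same batch, as is typical for tensor-parallel inference, and a single failure fails the whole group instead of being retried or replaced. It also assumes that rank 0's output can stand in for the group's result.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List, Sequence


class WorkerGroupError(RuntimeError):
    """Raised when any worker in an inelastic group fails."""


def run_group(workers: Sequence[Callable[[str], str]], batch: str) -> str:
    """Run the same batch on every worker; succeed only if all of them do."""
    with ThreadPoolExecutor(max_workers=len(workers)) as pool:
        futures = [pool.submit(worker, batch) for worker in workers]
        results: List[str] = []
        for rank, future in enumerate(futures):
            try:
                results.append(future.result())
            except Exception as exc:
                # No retry, no replacement: each rank holds a different model
                # shard, so the group cannot continue without it.
                raise WorkerGroupError(f"worker {rank} failed") from exc
    # Assumption: all ranks produce the same output, so rank 0 speaks for the group.
    return results[0]


if __name__ == "__main__":
    healthy = [str.upper] * 4
    print(run_group(healthy, "hello"))
```

In a real `torch.distributed` group, a dead rank usually makes the remaining ranks block in a collective until a timeout, so the failure tends to surface as a hang or timeout rather than a clean exception.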

## Known issues

1. Scheduling multiple worker groups on one node results in workers using the same CUDA devices, which causes a crash. Therefore, it's best to either use single-GPU nodes, or make sure that the number of workers in a group divided by the number of nodes equals the number of GPUs per node.
2. Some models from the Hugging Face Hub cause exceptions due to a [bug in DeepSpeed](https://github.com/microsoft/DeepSpeed/issues/3084). The workaround is to reshard the checkpoints of those models to ensure that all layers are stored in contiguous files. The relevant code is included in `huggingface_utils.py`.
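
Both issues can be checked before launching. The sketch below is hypothetical and is not the code in `huggingface_utils.py`. `placement_is_safe` encodes the rule of thumb from issue 1 literally. `device_collisions` assumes, for illustration, that the crash comes from two workers on one node claiming the same CUDA device index. `layers_split_across_files` reads the `weight_map` (tensor name to shard file) from a Hugging Face `*.index.json` file and flags layers whose tensors are spread over more than one file; treating that as the condition behind issue 2 is an assumption, and the layer-name pattern assumes OPT-style names such as `model.decoder.layers.3.fc1.weight`.

```python
import re
from collections import Counter, defaultdict
from typing import Dict, Iterable, List, Tuple


def placement_is_safe(num_gpus_per_worker_group: int, num_nodes: int, gpus_per_node: int) -> bool:
    """Rule of thumb from issue 1: single-GPU nodes, or workers per group
    divided by the number of nodes equal to the GPUs per node."""
    if gpus_per_node == 1:
        return True
    return num_gpus_per_worker_group == num_nodes * gpus_per_node


def device_collisions(placements: Iterable[Tuple[int, str, int]]) -> List[Tuple[str, int]]:
    """placements holds (group, node, cuda_device) per worker; returns the
    (node, cuda_device) pairs claimed by more than one worker."""
    counts = Counter((node, device) for _group, node, device in placements)
    return sorted(key for key, n in counts.items() if n > 1)


# Captures the per-layer prefix, e.g. "model.decoder.layers.3" (OPT-style naming assumed).
LAYER_PREFIX = re.compile(r"(.*\.layers\.\d+)\.")


def layers_split_across_files(weight_map: Dict[str, str]) -> List[str]:
    """Return the layers whose tensors are stored in more than one shard file."""
    files_per_layer = defaultdict(set)
    for tensor_name, shard_file in weight_map.items():
        match = LAYER_PREFIX.match(tensor_name)
        if match:
            files_per_layer[match.group(1)].add(shard_file)
    return sorted(layer for layer, files in files_per_layer.items() if len(files) > 1)


if __name__ == "__main__":
    # One 12-worker group on 3 x g4dn.12xlarge satisfies the rule.
    print(placement_is_safe(12, 3, 4))
    # Two 4-worker groups on the same 4-GPU node, each using devices 0..3, collide.
    print(device_collisions([(g, "node-a", d) for g in range(2) for d in range(4)]))
```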

## Environment

Key packages:
```
accelerate==0.17.1
deepspeed==0.8.3
ray @ https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-3.0.0.dev0-cp38-cp38-manylinux2014_x86_64.whl
torch==2.0.0
transformers==4.27.2
```
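
Before launching, it can help to confirm that the installed key packages match the pins above. This is a minimal sketch using the standard library's `importlib.metadata` (available from Python 3.8, which the `cp38` Ray wheel implies); `parse_pins`, `installed_version`, and `mismatches` are hypothetical helper names. The `ray` line is skipped because it is a wheel URL rather than an `==` pin.

```python
from importlib import metadata
from typing import Callable, Dict, Optional, Tuple

KEY_PINS = """
accelerate==0.17.1
deepspeed==0.8.3
ray @ https://s3-us-west-2.amazonaws.com/ray-wheels/latest/ray-3.0.0.dev0-cp38-cp38-manylinux2014_x86_64.whl
torch==2.0.0
transformers==4.27.2
"""


def parse_pins(text: str) -> Dict[str, str]:
    """Map package name -> pinned version for every 'name==version' line."""
    pins: Dict[str, str] = {}
    for line in text.splitlines():
        line = line.strip()
        if "==" in line:
            name, version = line.split("==", 1)
            pins[name.strip()] = version.strip()
    return pins


def installed_version(name: str) -> Optional[str]:
    """Installed version string, or None if the package is missing."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return None


def mismatches(
    pins: Dict[str, str],
    lookup: Callable[[str], Optional[str]] = installed_version,
) -> Dict[str, Tuple[str, Optional[str]]]:
    """Return {name: (pinned, installed)} for every package that differs."""
    result: Dict[str, Tuple[str, Optional[str]]] = {}
    for name, wanted in pins.items():
        found = lookup(name)
        if found != wanted:
            result[name] = (wanted, found)
    return result


if __name__ == "__main__":
    print(mismatches(parse_pins(KEY_PINS)))
```

The comparison is an exact string match, so an installed version with a local build tag (for example `2.0.0+cu117`) would also be reported as a mismatch.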

All packages:
```
absl-py==1.4.0
accelerate==0.17.1
adal==1.2.7
aim==3.16.1
aim-ui==3.16.1
aimrecords==0.0.7
aimrocks==0.3.1
aiofiles==22.1.0
aiohttp==3.8.4
aiohttp-cors==0.7.0
aiorwlock==1.3.0
aiosignal==1.3.1
aiosqlite==0.18.0
ale-py==0.8.1
alembic==1.10.2
anyio==3.6.2
anyscale @ file:///home/ray/anyscale-0.0.0.dev0.tar.gz
anyscale-node-provider @ file:///home/ray/anyscale_node_provider-0.0.1.tar.gz
applicationinsights==0.11.10
argcomplete==1.12.3
argon2-cffi==21.3.0
argon2-cffi-bindings==21.2.0
arrow==1.2.3
asttokens==2.2.1
astunparse==1.6.3
async-timeout==4.0.2
attrs==22.2.0
autocfg==0.0.8
autogluon.common==0.7.0
autogluon.core==0.7.0
autograd==1.5
autopage==0.5.1
AutoROM==0.6.0
AutoROM.accept-rom-license==0.6.0
awscli==1.25.6
awscliv2==2.2.0
ax-platform==0.3.1
azure-cli-core==2.40.0
azure-cli-telemetry==1.0.8
azure-common==1.1.28
azure-core==1.26.3
azure-identity==1.10.0
azure-mgmt-compute==23.1.0
azure-mgmt-core==1.3.2
azure-mgmt-network==19.0.0
azure-mgmt-resource==20.0.0
Babel==2.12.1
backcall==0.2.0
backoff==1.10.0
backports.zoneinfo==0.2.1
base58==2.0.1
bayesian-optimization==1.2.0
bcrypt==4.0.1
beautifulsoup4==4.12.0
bitsandbytes==0.37.2
black==23.1.0
bleach==6.0.0
blessed==1.20.0
blobfile==2.0.1
boto3==1.26.95
botocore==1.29.95
botorch==0.8.3
cached-property==1.5.2
cachetools==5.3.0
catboost==1.1.1
certifi==2022.12.7
cffi @ file:///tmp/abs_98z5h56wf8/croots/recipe/cffi_1659598650955/work
chardet==5.1.0
charset-normalizer==3.1.0
chess==1.7.0
chex==0.1.6
click==8.1.3
cliff==4.2.0
cloudpickle==2.2.1
cma==2.7.0
cmaes==0.9.1
cmake==3.26.0
cmd2==2.4.3
colorama==0.4.6
coloredlogs==15.0.1
colorful==0.5.5
colorlog==6.7.0
comet-ml==3.31.9
comm==0.1.2
commonmark==0.9.1
conda==23.1.0
conda-content-trust @ file:///tmp/abs_5952f1c8-355c-4855-ad2e-538535021ba5h26t22e5/croots/recipe/conda-content-trust_1658126371814/work
conda-package-handling @ file:///croot/conda-package-handling_1666940373510/work
configobj==5.0.8
ConfigSpace==0.4.18
contourpy==1.0.7
coolname==2.2.0
cryptography @ file:///croot/cryptography_1673298753778/work
cycler==0.11.0
Cython==0.29.32
databricks-cli==0.17.5
DataProperty==0.55.0
datasets==2.10.1
debugpy==1.6.6
decorator==5.1.1
decord==0.6.0
deepspeed==0.8.3
defusedxml==0.7.1
Deprecated==1.2.13
diffusers @ git+https://github.com/huggingface/diffusers.git@7fe88613fa15d230d59482889c440c7befa17c25
dill==0.3.6
distlib==0.3.6
dm-tree==0.1.8
docker==6.0.1
docker-pycreds==0.4.0
docutils==0.16
dopamine-rl==4.0.5
dragonfly-opt==0.1.6
dulwich==0.21.3
einops==0.3.0
entrypoints==0.4
etils==1.1.1
evaluate==0.4.0
everett==3.1.0
exceptiongroup==1.1.1
executing==1.2.0
executor==23.2
expiringdict==1.2.2
fastapi==0.95.0
fasteners==0.18
fastjsonschema==2.16.3
filelock==3.10.0
FLAML==1.1.1
Flask==2.2.3
flatbuffers==2.0.7
flax==0.6.7
fonttools==4.39.2
fqdn==1.5.1
freezegun==1.1.0
frozenlist==1.3.3
fsspec==2023.3.0
ftfy==6.1.1
future==0.18.3
gast==0.4.0
gin-config==0.5.0
gitdb==4.0.10
GitPython==3.1.31
glfw==2.5.7
gluoncv==0.10.1.post0
google-api-core==2.11.0
google-api-python-client==1.7.8
google-auth==2.16.2
google-auth-httplib2==0.1.0
google-auth-oauthlib==0.4.6
google-cloud-compute==1.10.1
google-cloud-core==2.3.2
google-cloud-resource-manager==1.9.0
google-cloud-secret-manager==2.16.0
google-cloud-storage==2.7.0
google-crc32c==1.5.0
google-oauth==1.0.1
google-pasta==0.2.0
google-resumable-media==2.4.1
googleapis-common-protos==1.58.0
gpustat==1.0.0
GPy==1.10.0
gpytorch==1.9.1
graphviz==0.8.4
greenlet==2.0.2
grpc-google-iam-v1==0.12.6
grpcio==1.51.3
grpcio-status==1.48.2
grpcio-tools==1.51.3
gunicorn==20.1.0
gym==0.26.2
gym-notices==0.0.8
Gymnasium==0.26.3
gymnasium-notices==0.0.1
h11==0.14.0
h5py==3.7.0
halo==0.0.31
HEBO==0.3.2
higher==0.2.1
hjson==3.1.0
hpbandster==0.7.4
httplib2==0.21.0
huggingface-hub==0.13.3
humanfriendly==10.0
humanize==4.6.0
hyperopt==0.2.5
idna==3.4
imageio==2.26.1
imageio-ffmpeg==0.4.5
importlib-metadata==6.1.0
importlib-resources==5.12.0
iniconfig==2.0.0
ipykernel==6.22.0
ipython==8.11.0
ipython-genutils==0.2.0
ipywidgets==8.0.4
isodate==0.6.1
isoduration==20.11.0
isort==5.12.0
itsdangerous==2.1.2
jax==0.4.6
jaxlib==0.4.6
jedi==0.18.2
Jinja2==3.1.2
jmespath==0.10.0
joblib==1.2.0
json5==0.9.11
jsonlines==3.1.0
jsonpatch==1.32
jsonpointer==2.3
jsonschema==4.17.3
jupyter-events==0.6.3
jupyter-ydoc==0.2.3
jupyter_client==8.1.0
jupyter_core==5.3.0
jupyter_server==2.5.0
jupyter_server_fileid==0.8.0
jupyter_server_terminals==0.4.4
jupyter_server_ydoc==0.6.1
jupyterlab==3.6.1
jupyterlab-pygments==0.2.2
jupyterlab-widgets==3.0.5
jupyterlab_server==2.20.0
kaggle-environments==1.7.11
keras==2.11.0
kiwisolver==1.4.4
knack==0.10.1
kubernetes==26.1.0
lazy_loader==0.1
libclang==15.0.6.1
libtorrent==2.0.7
lightgbm==3.3.5
lightgbm-ray==0.1.8
lightning-bolts==0.4.0
lightning-utilities==0.8.0
linear-operator==0.3.0
lit==16.0.0
lm-dataformat @ git+https://github.com/EleutherAI/lm_dataformat.git@4eec05349977071bf67fc072290b95e31c8dd836
lm-eval==0.3.0
log-symbols==0.0.14
lxml==4.9.2
lz4==4.3.2
Mako==1.2.4
Markdown==3.4.1
markdown-it-py==2.2.0
MarkupSafe==2.1.2
matplotlib==3.7.1
matplotlib-inline==0.1.6
mbstrdecoder==1.1.2
mdurl==0.1.2
minigrid==2.1.1
mistune==2.0.5
mlagents-envs==0.28.0
mlflow==1.30.0
modin==0.18.1
monotonic==1.6
mosaicml==0.12.1
mpmath==1.3.0
msal==1.18.0b1
msal-extensions==1.0.0
msgpack==1.0.5
msrest==0.7.1
msrestazure==0.6.4
mujoco==2.2.0
mujoco-py==2.1.2.14
multidict==6.0.4
multipledispatch==0.6.0
multiprocess==0.70.14
mxnet==1.8.0.post0
mypy-extensions==1.0.0
nbclassic==0.5.3
nbclient==0.7.2
nbconvert==7.2.10
nbformat==5.8.0
nest-asyncio==1.5.6
netifaces==0.11.0
networkx==3.0
nevergrad==0.4.3.post7
ninja==1.11.1
nltk==3.8.1
notebook==6.5.3
notebook_shim==0.2.2
numexpr==2.8.4
numpy==1.23.5
nvidia-cublas-cu11==11.10.3.66
nvidia-cuda-cupti-cu11==11.7.101
nvidia-cuda-nvrtc-cu11==11.7.99
nvidia-cuda-runtime-cu11==11.7.99
nvidia-cudnn-cu11==8.5.0.96
nvidia-cufft-cu11==10.9.0.58
nvidia-curand-cu11==10.2.10.91
nvidia-cusolver-cu11==11.4.0.1
nvidia-cusparse-cu11==11.7.4.91
nvidia-ml-py==11.495.46
nvidia-nccl-cu11==2.14.3
nvidia-nvtx-cu11==11.7.91
oauth2client==4.1.3
oauthlib==3.2.2
onnx==1.12.0
onnxruntime==1.14.1
open-spiel==1.2
openai==0.27.2
opencensus==0.11.2
opencensus-context==0.1.3
opencv-python==4.7.0.72
opentelemetry-api==1.1.0
opentelemetry-exporter-otlp==1.1.0
opentelemetry-exporter-otlp-proto-grpc==1.1.0
opentelemetry-exporter-otlp-proto-http==1.16.0
opentelemetry-proto==1.1.0
opentelemetry-sdk==1.1.0
opentelemetry-semantic-conventions==0.20b0
opt-einsum==3.3.0
optax==0.1.4
optuna==2.10.0
orbax==0.1.5
packaging==23.0
pandas==1.5.3
pandocfilters==1.5.0
paramiko==2.12.0
paramz==0.9.5
parso==0.8.3
pathspec==0.11.1
pathtools==0.1.2
pathvalidate==2.5.2
patsy==0.5.3
pbr==5.11.1
PettingZoo==1.22.1
pexpect==4.8.0
pickleshare==0.7.5
Pillow==9.4.0
pkginfo==1.9.6
pkgutil_resolve_name==1.3.10
platformdirs==3.1.1
plotly==5.13.1
pluggy @ file:///tmp/build/80754af9/pluggy_1648042571233/work
portalocker==2.7.0
prettytable==3.6.0
prometheus-client==0.13.1
prometheus-flask-exporter==0.22.3
promise==2.3
prompt-toolkit==3.0.38
property-manager==3.0
proto-plus==1.22.2
protobuf==3.20.3
psutil==5.9.4
ptyprocess==0.7.0
pure-eval==0.2.2
py-cpuinfo==9.0.0
py-spy==0.3.14
py3nvml==0.2.7
pyaml==21.10.1
pyarrow==11.0.0
pyasn1==0.4.8
pyasn1-modules==0.2.8
pybind11==2.6.2
pycosat @ file:///croot/pycosat_1666805502580/work
pycountry==22.3.5
pycparser @ file:///tmp/build/80754af9/pycparser_1636541352034/work
pycryptodomex==3.17
pydantic==1.10.6
pyDeprecate==0.3.2
pygame==2.1.2
pyglet==1.5.15
Pygments==2.14.0
PyJWT==2.6.0
pymoo==0.5.0
pymunk==6.2.1
PyNaCl==1.5.0
PyOpenGL==3.1.6
pyOpenSSL==23.0.0
pyparsing==3.0.9
pyperclip==1.8.2
pypng==0.20220715.0
pyro-api==0.1.2
pyro-ppl==1.8.4
Pyro4==4.82
pyrsistent==0.19.3
PySocks @ file:///tmp/build/80754af9/pysocks_1605305779399/work
pytablewriter==0.64.2
pytest==7.2.2
pytest-remotedata==0.3.2
python-dateutil==2.8.2
python-json-logger==2.0.7
pytorch-lightning==2.0.0
pytorch-ranger==0.1.1
pytz==2022.7.1
pytz-deprecation-shim==0.1.0.post0
PyWavelets==1.4.1
PyYAML==6.0
pyzmq==25.0.2
querystring-parser==1.2.4
ray @ file:///home/ray/ray-3.0.0.dev0-cp38-cp38-manylinux2014_x86_64.whl
ray-lightning==0.3.0
recsim==0.2.4
redis==3.5.3
regex==2022.10.31
requests==2.28.2
requests-oauthlib==1.3.1
requests-toolbelt==0.10.1
responses==0.18.0
RestrictedPython==6.0
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rich==12.0.1
rouge-score==0.1.2
rsa==4.9
ruamel.yaml @ file:///croot/ruamel.yaml_1666304550667/work
ruamel.yaml.clib @ file:///croot/ruamel.yaml.clib_1666302247304/work
s3transfer==0.6.0
sacrebleu==1.5.0
scikit-image==0.20.0
scikit-learn==1.2.2
scikit-optimize==0.9.0
scipy==1.10.1
segment-analytics-python==2.2.2
semantic-version==2.10.0
Send2Trash==1.8.0
sentencepiece==0.1.96
sentry-sdk==1.17.0
serpent==1.41
setproctitle==1.3.2
shortuuid==1.0.1
sigopt==7.5.0
six==1.16.0
smart-open==6.3.0
smmap==5.0.0
sniffio==1.3.0
soupsieve==2.4
spinners==0.0.24
SQLAlchemy==1.4.47
sqlitedict==2.1.0
sqlparse==0.4.3
stack-data==0.6.2
starlette==0.26.1
statsmodels==0.13.5
stevedore==5.0.0
SuperSuit==3.7.0
sympy==1.11.1
tabledata==1.3.1
tabulate==0.9.0
tblib==1.7.0
tcolorpy==0.1.2
tenacity==8.2.2
tensorboard==2.12.0
tensorboard-data-server==0.7.0
tensorboard-plugin-wit==1.8.1
tensorboardX==2.4.1
tensorflow-estimator==2.11.0
tensorflow-io-gcs-filesystem==0.31.0
tensorflow-probability==0.19.0
tensorstore==0.1.33
termcolor==2.2.0
terminado==0.10.1
tf-slim==1.1.0
tf2onnx==1.13.0
threadpoolctl==3.1.0
tifffile==2023.3.15
tiktoken==0.1.2
timm==0.4.5
tinycss2==1.2.1
tinyscaler==1.2.5
tokenizers==0.13.2
tomli==2.0.1
toolz @ file:///croot/toolz_1667464077321/work
torch==2.0.0
torch-optimizer==0.3.0
torchaudio==2.0.1
torchmetrics==0.11.4
torchvision==0.15.1
tornado==6.2
tqdm==4.65.0
tqdm-multiprocess==0.0.11
traitlets==5.9.0
transformers==4.27.2
triton==2.0.0
tune-sklearn==0.4.4
typeguard==2.13.3
typepy==1.3.0
typer==0.6.1
typing_extensions==4.5.0
tzdata==2022.7
tzlocal==4.3
ujson==5.7.0
uri-template==1.2.0
uritemplate==3.0.1
urllib3==1.26.15
uvicorn==0.21.1
verboselogs==1.7
virtualenv==20.21.0
wandb==0.13.4
wcwidth==0.2.6
webcolors==1.12
webencodings==0.5.1
websocket-client==1.5.1
Werkzeug==2.2.3
widgetsnbextension==4.0.5
wrapt==1.15.0
wurlitzer==3.0.3
xgboost==1.7.4
xgboost-ray==0.1.15
xmltodict==0.13.0
xxhash==3.2.0
y-py==0.5.9
yacs==0.1.8
yarl==1.8.2
ypy-websocket==0.8.2
zipp==3.15.0
zoopt==0.4.1
zstandard==0.20.0
```