https://github.com/ekzhang/archax
Experiments in multi-architecture parallelism for deep learning with JAX
- Host: GitHub
- URL: https://github.com/ekzhang/archax
- Owner: ekzhang
- License: MIT
- Created: 2022-09-28T02:13:29.000Z (over 2 years ago)
- Default Branch: main
- Last Pushed: 2022-12-11T02:24:00.000Z (over 2 years ago)
- Last Synced: 2025-01-19T17:11:38.651Z (3 months ago)
- Topics: cpu, gpu, jax, machine-learning, ml, parallelism, pipeline, tpu
- Language: Python
- Homepage:
- Size: 1.42 MB
- Stars: 3
- Watchers: 2
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
README
# archax
**Experiments in multi-architecture parallelism for deep learning with JAX.**

What if we could create a new kind of multi-architecture parallelism library for deep learning compilers, supporting expressive frontends like JAX? Such a library would optimize a mix of pipeline and operator parallelism across accelerated devices: use CPU, GPU, and/or TPU in the same program, and automatically interleave work between them.
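As a rough sketch of the idea (not code from this repository), JAX already lets a single program place different stages of a computation on different backends with `jax.device_put`. The accelerator lookup below is an assumption for illustration; it falls back to CPU when no GPU backend is present.

```python
import jax
import jax.numpy as jnp

# CPU backend is always available in JAX.
cpu = jax.devices("cpu")[0]

# Hypothetical device selection: prefer a GPU if one exists, else reuse the CPU.
# jax.devices("gpu") raises RuntimeError when no GPU backend is initialized.
try:
    accel = jax.devices("gpu")[0]
except RuntimeError:
    accel = cpu

# Stage 1 on CPU: build the input there explicitly.
x = jax.device_put(jnp.arange(8.0), device=cpu)

# Move the intermediate result to the accelerator for stage 2.
y = jax.device_put(jnp.sin(x), device=accel)
z = jnp.sum(y ** 2)  # scalar result computed on `accel`

print(z)
```

A multi-architecture library in the spirit described above would make this interleaving automatic rather than hand-written, choosing placements for pipeline and operator parallelism instead of relying on explicit `device_put` calls.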
Experiments are given in this repository, dated and annotated with brief descriptions.
## License
All code and notebooks in this repository are distributed under the terms of the [MIT](https://spdx.org/licenses/MIT.html) license.