https://github.com/tgale96/grouped_gemm
PyTorch bindings for CUTLASS grouped GEMM.
- Host: GitHub
- URL: https://github.com/tgale96/grouped_gemm
- Owner: tgale96
- License: apache-2.0
- Created: 2023-09-19T20:08:53.000Z (over 2 years ago)
- Default Branch: main
- Last Pushed: 2024-10-31T19:40:59.000Z (over 1 year ago)
- Last Synced: 2025-04-29T06:16:57.359Z (12 months ago)
- Language: Cuda
- Size: 43 KB
- Stars: 84
- Watchers: 4
- Forks: 57
- Open Issues: 12
Metadata Files:
- Readme: README.md
- License: LICENSE
Awesome Lists containing this project
- awesome-gemm - CUTLASS-based Grouped GEMM: Efficient grouped GEMM operations (Example Implementations 💡 / Blogs 🖋️)
README
# Grouped GEMM
A lightweight library exposing grouped GEMM kernels in PyTorch.
# Installation
Run `pip install grouped_gemm` to install the package.
# Compiling from source
By default, the installed package runs in conservative (`cuBLAS`) mode:
it launches one GEMM kernel per batch element instead of using a single
grouped GEMM kernel for the whole batch.
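Semantically, a grouped GEMM multiplies each group of stacked input rows by its own weight matrix; the conservative mode does this with one GEMM launch per group, while a grouped kernel fuses all of them into a single launch. A minimal NumPy sketch of those semantics (the function name and shapes here are illustrative, not this library's API):

```python
import numpy as np

def grouped_gemm_reference(a, b, group_sizes):
    """Reference semantics of a grouped GEMM.

    a           : (sum(group_sizes), k) inputs for all groups, stacked row-wise
    b           : (num_groups, k, n) one weight matrix per group
    group_sizes : row counts, one per group

    Mirrors the conservative mode described above: one matmul per group.
    """
    outputs = []
    start = 0
    for g, rows in enumerate(group_sizes):
        outputs.append(a[start:start + rows] @ b[g])  # (rows, n)
        start += rows
    return np.concatenate(outputs, axis=0)  # (sum(group_sizes), n)

# Two groups of different sizes sharing one stacked input tensor.
rng = np.random.default_rng(0)
a = rng.standard_normal((5, 4))   # 3 rows for group 0, 2 rows for group 1
b = rng.standard_normal((2, 4, 8))
out = grouped_gemm_reference(a, b, [3, 2])
print(out.shape)  # (5, 8)
```

A fused grouped kernel produces the same result as this loop but avoids the per-group launch overhead, which matters when there are many small groups (e.g. experts in a mixture-of-experts layer).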
To use the grouped GEMM kernels, switch to `CUTLASS` mode by setting the
`GROUPED_GEMM_CUTLASS` environment variable to `1` when building the
library. For example, to build the library in `CUTLASS` mode for Ampere
(SM 8.0), clone the repository and run the following:
```bash
$ TORCH_CUDA_ARCH_LIST=8.0 GROUPED_GEMM_CUTLASS=1 pip install .
```
See [this comment](https://github.com/tgale96/grouped_gemm/pull/14#issuecomment-2211362572)
for some performance measurements on A100 and H100.
# Benchmark example
```bash
$ python benchmark.py
```
# Upcoming features
* Hopper-optimized grouped GEMM kernels.