https://github.com/jundaf2/cuda-int8-gemm
CUDA 8-bit Tensor Core Matrix Multiplication based on m16n16k16 WMMA API
- Host: GitHub
- URL: https://github.com/jundaf2/cuda-int8-gemm
- Owner: jundaf2
- Created: 2023-06-15T13:46:24.000Z (over 1 year ago)
- Default Branch: main
- Last Pushed: 2023-09-15T18:38:30.000Z (over 1 year ago)
- Last Synced: 2024-08-04T02:06:39.094Z (6 months ago)
- Language: Cuda
- Homepage:
- Size: 4.29 MB
- Stars: 19
- Watchers: 3
- Forks: 2
- Open Issues: 1
Metadata Files:
- Readme: README.md
README
# CUDA-INT8-GEMM
## Introduction
The 8-bit GEMM takes two 8-bit input matrices and produces an output matrix that is also 8-bit: `C = A * B^T`.
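As a point of comparison for the GPU kernel, the arithmetic of `C = A * B^T` can be sketched on the CPU. This is a minimal reference under stated assumptions, not code from this repository: the function name `gemm_s8_ref` is hypothetical, both inputs are stored row-major for readability, and the result is left in `int32` so that the narrowing to `int8` stays a separate step.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical CPU reference: C = A * B^T with int8 inputs and int32 accumulation.
// A is M x K and B is N x K, both row-major; the result C is M x N, row-major.
std::vector<int32_t> gemm_s8_ref(const std::vector<int8_t>& A,
                                 const std::vector<int8_t>& B,
                                 int M, int N, int K) {
    assert(A.size() == static_cast<std::size_t>(M) * K);
    assert(B.size() == static_cast<std::size_t>(N) * K);
    std::vector<int32_t> C(static_cast<std::size_t>(M) * N, 0);
    for (int m = 0; m < M; ++m) {
        for (int n = 0; n < N; ++n) {
            int32_t acc = 0;
            // Row m of A dotted with row n of B, which is column n of B^T.
            for (int k = 0; k < K; ++k) {
                acc += static_cast<int32_t>(A[m * K + k]) *
                       static_cast<int32_t>(B[n * K + k]);
            }
            C[m * N + n] = acc;
        }
    }
    return C;
}
```

A reference like this is mainly useful for checking a kernel's output on small sizes before moving to larger matrices.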
We adopt the same convention as the cuBLAS library, where the matrices are stored in column-major order. `GEMM_OP_T` means the matrix is transposed in column-major representation, which is equivalent to the non-transposed matrix in row-major representation. `GEMM_OP_N` means the matrix is not transposed in column-major representation, which is equivalent to the transposed matrix in row-major representation. The same convention applies to matrix C.
You may understand the `T` and `N` in these flags either as the `transpose` / `non-transpose` operation for column-major BLAS (Fortran) matrices or as `true` / `not true` for row-major C/C++ matrices.
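The equivalence between the two readings comes down to index arithmetic: the same buffer read as a column-major matrix `X` is the row-major matrix `X^T`. The sketch below checks this for every element; the helper names are hypothetical and are not part of cuBLAS or of this repository.

```cpp
#include <cassert>
#include <cstddef>

// Offset of element (row, col) in a column-major matrix with leading dimension ld.
inline std::size_t col_major_index(int row, int col, int ld) {
    return static_cast<std::size_t>(col) * ld + row;
}

// Offset of element (row, col) in a row-major matrix with leading dimension ld.
inline std::size_t row_major_index(int row, int col, int ld) {
    return static_cast<std::size_t>(row) * ld + col;
}

// Returns true if element (i, j) of a column-major rows x cols matrix always sits
// at the same offset as element (j, i) of a row-major cols x rows matrix.
bool column_major_is_row_major_transpose(int rows, int cols) {
    assert(rows > 0 && cols > 0);
    for (int i = 0; i < rows; ++i) {
        for (int j = 0; j < cols; ++j) {
            if (col_major_index(i, j, rows) != row_major_index(j, i, rows)) {
                return false;
            }
        }
    }
    return true;
}
```

This is why a buffer written by row-major C/C++ code can be handed to a column-major routine without copying, as long as the transpose flag is flipped.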
## The 8-bit WMMA Tensor Core API with Shape m16n16k16
Since there is no single PTX `mma.sync` instruction that performs an m16n16k16 8-bit matrix multiplication, we think the built-in intrinsic `__imma_m16n16k16_mma_s8` is composed of four `mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32` instructions. The following figure shows how the four 8-bit m8n8k16 instructions make up one m16n16k16 built-in intrinsic.

For simplicity, and without much consideration for performance in this example, we use `cp.async.ca.shared.global` to load the data from global memory to shared memory asynchronously. `wmma::load_matrix_sync` loads the data from shared memory to registers, and `wmma::mma_sync` performs the matrix multiplication.

For the detailed register data layout of the WMMA 8-bit m16n16k16 API, please see the following figure.
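The arithmetic behind the "four m8n8k16 tiles" view can be sketched on the CPU: a 16x16 output splits into a 2x2 grid of 8x8 tiles, each using the full K = 16. The code below is only an illustration of that decomposition under assumed layouts (A row-major, B column-major, C row-major); the function names are hypothetical, and it says nothing about the register layout or about the instructions the compiler actually emits.

```cpp
#include <cassert>
#include <cstdint>

// Emulates the arithmetic of one m8n8k16 tile: C(8x8) += A(8x16) * B(16x8).
// A is row-major (leading dimension lda), B is column-major (leading dimension ldb),
// C is row-major (leading dimension ldc).
void mma_m8n8k16_ref(const int8_t* A, int lda, const int8_t* B, int ldb,
                     int32_t* C, int ldc) {
    assert(lda >= 16 && ldb >= 16 && ldc >= 8);
    for (int i = 0; i < 8; ++i) {
        for (int j = 0; j < 8; ++j) {
            int32_t acc = C[i * ldc + j];
            for (int k = 0; k < 16; ++k) {
                acc += static_cast<int32_t>(A[i * lda + k]) *
                       static_cast<int32_t>(B[j * ldb + k]);
            }
            C[i * ldc + j] = acc;
        }
    }
}

// Builds a 16x16x16 product from a 2x2 grid of m8n8k16 tiles.
// A, B and C are dense 16x16 matrices (leading dimension 16).
void mma_m16n16k16_ref(const int8_t* A, const int8_t* B, int32_t* C) {
    for (int ti = 0; ti < 2; ++ti) {      // tile row: rows ti*8 .. ti*8+7 of A and C
        for (int tj = 0; tj < 2; ++tj) {  // tile column: columns tj*8 .. tj*8+7 of B and C
            mma_m8n8k16_ref(A + ti * 8 * 16, 16,
                            B + tj * 8 * 16, 16,
                            C + ti * 8 * 16 + tj * 8, 16);
        }
    }
}
```

Each of the four calls performs 8 * 8 * 16 = 1024 multiply-accumulates, which together account for the 16 * 16 * 16 = 4096 of the full tile.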
## Current feature
The output is also of type `int8`. For example, when you use GEMM in an 8-bit framework, you may want to use the `int8` output as the input of the next layer's operation, even though the tensor core itself uses `int32` as the accumulator.
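How the `int32` accumulator is narrowed to `int8` is not spelled out here. One common approach, shown purely as an assumption, is to scale, round to nearest, and saturate to the `int8` range. The function name `requantize_s32_to_s8` and the `scale` parameter are hypothetical and may not match what this repository does.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>

// Hypothetical narrowing step: scale an int32 accumulator, round to nearest,
// then saturate to [-128, 127] before converting to int8.
inline int8_t requantize_s32_to_s8(int32_t acc, float scale) {
    assert(std::isfinite(scale));
    float scaled = std::nearbyint(static_cast<float>(acc) * scale);
    scaled = std::min(127.0f, std::max(-128.0f, scaled));
    return static_cast<int8_t>(scaled);
}
```

Saturating instead of wrapping keeps large accumulator values pinned at the ends of the range, which is usually what the next 8-bit layer expects.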
Performance is currently quite poor due to:
* unresolved bank conflicts when loading data from shared memory to registers
* unoptimized global memory writes

Currently, you can try different matrix multiplication sizes with the following command (you may need to tune the block size and grid size in the code):
```
./test_gemm_i8 1024 1024 1024 1 0 1 1
```
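On the bank-conflict issue listed above, the effect of the shared-memory row stride can be estimated with plain index arithmetic. The sketch below assumes 32 banks of 4 bytes each and a hypothetical access pattern in which each of the 32 threads in a warp reads the first 4-byte word of a different row. It is not the access pattern of `wmma::load_matrix_sync`, which is opaque, and the function name is made up for illustration.

```cpp
#include <cassert>
#include <set>

// Counts the distinct shared-memory banks touched when 32 threads each read the
// first 4-byte word of consecutive rows with the given row stride in bytes.
// Assumes 32 banks, each 4 bytes wide.
int distinct_banks(int row_stride_bytes) {
    assert(row_stride_bytes > 0 && row_stride_bytes % 4 == 0);
    std::set<int> banks;
    for (int t = 0; t < 32; ++t) {
        int word = t * row_stride_bytes / 4;  // 4-byte word index of thread t's access
        banks.insert(word % 32);
    }
    return static_cast<int>(banks.size());
}
```

Under these assumptions a 128-byte stride maps every thread to the same bank, while padding each row to 132 bytes spreads the accesses over all 32 banks. Padding of this kind is a common starting point when tuning a shared-memory layout.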