https://github.com/mrowan137/mnist-from-scratch-hip
From-scratch implementation of MNIST linear classifier in HIP, tuned for Radeon Pro VII (gfx906)
https://github.com/mrowan137/mnist-from-scratch-hip
amd-gpu cpp hip machine-learning mnist rocm
Last synced: about 7 hours ago
JSON representation
From-scratch implementation of MNIST linear classifier in HIP, tuned for Radeon Pro VII (gfx906)
- Host: GitHub
- URL: https://github.com/mrowan137/mnist-from-scratch-hip
- Owner: mrowan137
- License: mit
- Created: 2026-04-12T23:28:16.000Z (3 months ago)
- Default Branch: main
- Last Pushed: 2026-06-10T06:51:54.000Z (25 days ago)
- Last Synced: 2026-07-05T00:30:27.722Z (about 7 hours ago)
- Topics: amd-gpu, cpp, hip, machine-learning, mnist, rocm
- Language: C++
- Homepage:
- Size: 11.1 MB
- Stars: 2
- Watchers: 0
- Forks: 0
- Open Issues: 0
-
Metadata Files:
- Readme: README.md
- License: LICENSE
Awesome Lists containing this project
README
# MNIST from scratch in HIP
Digit classification with the MNIST dataset is used widely as a 'Hello World' problem in machine learning.
In the age of capable ML frameworks and abstractions, it is easier than ever to produce a working MNIST classifier.
While powerful, these can obscure the underlying mechanisms.
This small project aims to restore some of that view by implementing a simple MNIST linear classification model from scratch in HIP, tuned for Radeon Pro VII (gfx906).
MNIST test dataset visualization and predictions (300 training iterations, learning rate 0.01).
## Requirements
- AMD GPU (code is tuned for Radeon Pro VII / gfx906)
- Compatible ROCm installation (see: https://rocm.docs.amd.com/en/latest/compatibility/compatibility-matrix.html)
## Build and run
```
make
./mnist ./dataset [iterations] [learning rate]
```
The optional arguments default to 100 iterations and learning rate of 0.01.
## Model
Let $X^{jk}$ denote images (training or test), where $j$ indexes the image, and $k$ indexes pixels (plus bias).
Let $W^{ik}$ denote model weights, where $i$ indexes digit class (0–9).
Repeated indices are implicitly summed and unique indices denote free matrix component labels.
We use all-up convention, so up/down indices carry no geometric significance.
Class scores are given by:
$$
\hat{Y}^{ij} = W^{ik} X^{jk},
$$
and loss by the average squared residual:
$$
L := \frac{1}{N}\lVert\hat{Y}^{ij} - Y_{\mathrm{true}}^{ij}\rVert_{\mathrm{F}}^{2},
$$
where $N$ is the number of images. Training proceeds by iterative update of the weights via gradient descent:
$$
\begin{aligned}
W^{ik} &\leftarrow W^{ik} - \Delta \frac{\partial L }{\partial W^{ik}}\\
&= W^{ik} - \frac{2 \Delta}{N} ( \hat{Y}^{ij} - Y_{\mathrm{true}}^{ij}) X^{jk},
\end{aligned}
$$
where $\Delta$ is the learning rate.
## Training loss and accuracy
On a Radeon Pro VII test system with ROCm 6.4.3, this implementation achieved 84.8% training / 85.5% test accuracy in 300 iterations using learning rate 0.01, 4.18 ms / iteration.
MNIST training loss over 300 iterations (learning rate 0.01).
MNIST training accuracy over 300 iterations (learning rate 0.01).
## Optimizations
The majority of time in the training loop is spent in `sgemm_sub_and_scale` and `sgemm`, which roughly account for forward pass and back propagation, respectively.
The primary mathematical operation in each of these steps is a matrix multiplication.
Key performance considerations and optimizations incorporated across the training kernels are:
- data layout: format matrices to be amenable to coalesced memory accesses (e.g. transpose matrices)
- ping-pong buffering: overlap load of $k+1^{th}$ tile with compute of $k^{th}$ tile
- shared memory tiling: compute partial dot products in a way that optimizes data reuse
- hand-tuned tiling parameters
- occupancy tuning via launch bounds, thread, and block parameters
- loop unrolling (not always a win, but can help if it does not induce register spillage)
- warp shuffle reduction: reduce atomics contention
- kernel fusion: store $2(\hat{Y} - Y)$ during forward pass, in preparation for back propagation
## References
- MNIST dataset:
- https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz
- https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz
- https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz
- https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz
- MNIST handwritten database (archived from Yann LeCun's website): https://web.archive.org/web/20200430193701/http://yann.lecun.com/exdb/mnist/
- Reading MNIST dataset: https://stackoverflow.com/questions/8286668/how-to-read-mnist-data-in-c
- Character ramp: https://gist.github.com/micycle1/507c3052a9fcf04520430440d0671ecb
- Random number initializations: https://github.com/joelkp/ranoise/blob/main/splitmix32.c
- Neural networks from scratch: https://karpathy.ai/zero-to-hero.html
- Radeon Pro VII specs: https://www.techpowerup.com/gpu-specs/radeon-pro-vii.c3575
- ROCm compatibility matrix: https://rocm.docs.amd.com/en/latest/compatibility/compatibility-matrix.html
- GEMM by hand: https://github.com/AyakaGEMM/Hands-on-GEMM/tree/main
## License
This project is licensed under the [MIT License](./LICENSE).