https://github.com/ryanrudes/cs148a
An overpowered digit classifier built for EE/CNS/CS 148A @ Caltech. Make digits out of anything.
https://github.com/ryanrudes/cs148a
digit-classification
Last synced: 5 days ago
JSON representation
An overpowered digit classifier built for EE/CNS/CS 148A @ Caltech. Make digits out of anything.
- Host: GitHub
- URL: https://github.com/ryanrudes/cs148a
- Owner: ryanrudes
- Created: 2026-02-06T01:39:06.000Z (4 months ago)
- Default Branch: main
- Last Pushed: 2026-03-23T02:02:49.000Z (3 months ago)
- Last Synced: 2026-05-31T18:33:11.960Z (16 days ago)
- Topics: digit-classification
- Language: Python
- Homepage:
- Size: 65.2 MB
- Stars: 1
- Watchers: 0
- Forks: 0
- Open Issues: 0
-
Metadata Files:
- Readme: README.md
Awesome Lists containing this project
README
# Digit Classifier
https://github.com/user-attachments/assets/e00540d3-96e7-4360-9dc5-eb5e3a0b5e29
https://github.com/user-attachments/assets/53767d51-1e46-442c-ae86-29178b18e6b4
https://github.com/user-attachments/assets/7913d9f3-2399-483e-9ce6-53ac047e823b
https://github.com/user-attachments/assets/b402ce11-aa25-47d0-b2e9-c745f2e88d0c
A PyTorch training pipeline for digit classification using **ResNeXt** with
YOLO-style augmentation, external dataset mixing, and comprehensive experiment
tracking.
## Features
- **ResNeXt-101** with stochastic depth (drop-path) and grouped convolutions
- **YOLO-style augmentation** pipeline with digit-safe hyper-parameters (label-conditional horizontal flip for symmetric digits 0 and 8)
- **External dataset mixing** — SVHN, MNIST, EMNIST, USPS, QMNIST and Semeion are lazily loaded and mixed into training via a `RatioBatchSampler` (95 % original / 5 % external per batch)
- **NIST-like deduplication** — SHA-1 fingerprints remove exact pixel duplicates across external datasets (cached as a tiny JSON so dedup only runs once)
- **Mixup / CutMix** (element-wise mode) with dynamic loss switching — `SoftTargetCrossEntropy` while active, plain `CrossEntropyLoss` when disabled for the final *N* epochs
- **EMA model** with `use_buffers=True`
- **Warm-restart scheduler** — linear warmup followed by `CosineAnnealingWarmRestarts` with pre-restart checkpoints
- **AMP** with `GradScaler` (CUDA-only) and gradient norm clipping
- **Weights & Biases** integration for experiment tracking and artifact storage
- **Rich** console output — tables, progress bars and structured logging
- **HuggingFace Hub** integration for pushing/pulling dataset caches
- **Webcam inference** with real-time probability visualisation
## Quick start (local)
```bash
# Create a virtual environment
python -m venv .venv
source .venv/bin/activate
# Install the package in development mode
pip install -e ".[dev]"
# 1. Download the raw dataset from Google Drive
python -m digit_classifier download
# 2. Preprocess and cache as .npz
python -m digit_classifier preprocess --name mnist_rgb_224 --color --size 224
# 3. Train (runs external dataset download + dedup on first run)
python -m digit_classifier train
# 4. Run webcam inference (uses EMA model; pass --mean/--std if checkpoint lacks them)
python -m digit_classifier infer --checkpoint checkpoints//best.pt \
--mean 0.57 0.52 0.48 --std 0.23 0.23 0.23
# 5. Export compiled pipeline (TorchScript)
This compiles the model _and_ preprocessing into a single TorchScript file
that accepts raw image tensors and can be uploaded to the Hub. The export
uses the **EMA** model from the checkpoint.
**Normalization (mean/std):** Checkpoints do not currently store mean/std.
You must pass `--mean` and `--std` so the exported pipeline matches your
training normalization. Get these from the training log (e.g. `mean=[0.13, 0.13, 0.13]`)
or from your cached dataset. If omitted, 0.5/0.5 is used (likely incorrect).
```bash
# RGB (3 channels) — use the mean/std printed during train
python -m digit_classifier export-pipeline \
--checkpoint checkpoints//best.pt \
--output pipeline-cnn.pt \
--mean 0.13 0.13 0.13 \
--std 0.31 0.31 0.31
# Grayscale (1 channel)
python -m digit_classifier export-pipeline \
--checkpoint checkpoints//best.pt \
--output pipeline-cnn.pt \
--input-channels 1 \
--mean 0.13 \
--std 0.31
# Upload to HuggingFace Hub
python -m digit_classifier export-pipeline \
--checkpoint checkpoints//best.pt \
--output pipeline-cnn.pt \
--mean 0.13 0.13 0.13 --std 0.31 0.31 0.31 \
--push-to-hf --hf-repo /
```
## Cloud training
After running locally at least once (so caches exist), push them to
HuggingFace Hub and pull on any cloud VM:
### Push caches (run locally)
```bash
# Authenticate with HuggingFace (one-time)
huggingface-cli login
# Push internal dataset cache + dedup indices to a private HF repo
python -m digit_classifier push-cache --repo /digit-classification-cache
```
This uploads:
- `datasets/mnist_rgb_224.npz` — preprocessed internal dataset (~1.5 GB)
- `datasets/dedup_indices_*.json` — dedup index cache (a few KB)
### Pull caches and train (run on cloud VM)
```bash
# Install
pip install -e ".[dev]"
# Authenticate with HuggingFace
huggingface-cli login
# Or set the token directly:
# export HF_TOKEN=hf_...
# Pull caches from HuggingFace Hub
python -m digit_classifier pull-cache --repo /digit-classification-cache
# Train — external datasets (SVHN, MNIST, etc.) download automatically
# from torchvision on first access; dedup is skipped (cached indices)
python -m digit_classifier train
```
## CLI reference
| Command | Description |
|---|---|
| `download` | Fetch the raw JPEG archive from Google Drive |
| `preprocess` | Resize, colour-convert, compute mean/std and cache as `.npz` |
| `train` | Run the full training pipeline |
| `infer` | Real-time webcam digit recognition (uses EMA; pass `--mean`/`--std` if checkpoint lacks them) |
| `export-pipeline` | Compile model + preprocessing into a TorchScript pipeline (uses EMA; pass `--mean`/`--std` if checkpoint lacks them) |
| `visualize` | Debug-view augmented + mixed-up training batches |
| `push-cache` | Push dataset caches to a HuggingFace Hub repo |
| `pull-cache` | Pull dataset caches from a HuggingFace Hub repo |
Every training hyper-parameter is exposed as a CLI flag with the current
defaults. Run `python -m digit_classifier train --help` for the full list.
## Project layout
```
src/digit_classifier/
__init__.py Package version
__main__.py CLI entry-point (argparse subcommands)
config.py Dataclass-based configuration
model.py ResNeXt + DropPath
dataset.py DigitDataset (tensor-backed)
splitting.py Train/val split + external mixing + stats
external.py ExternalOnDemandDataset + deduplication
sampler.py RatioBatchSampler
augmentation.py YOLO-style augmentation pipeline
mixup.py timm Mixup/CutMix wrapper
preprocessing.py Download + preprocess + cache
training.py Training loop
inference.py Webcam inference
visualize.py Debug batch visualisation
hub.py HuggingFace Hub push/pull
tests/ pytest test suite
```
## Running tests
```bash
pytest
```