https://github.com/canreader/mnistpy
This is my first AI model, which uses TensorFlow/Keras.
- Host: GitHub
- URL: https://github.com/canreader/mnistpy
- Owner: CanReader
- Created: 2021-12-14T19:43:17.000Z
- Default Branch: main
- Last Pushed: 2021-12-14T19:43:50.000Z
- Last Synced: 2023-09-14T11:46:21.546Z
- Language: Python
- Size: 1.95 KB
- Stars: 0
- Watchers: 1
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
README
# MNIST Digit Classifier
A PyTorch implementation of a Convolutional Neural Network (CNN) for
handwritten digit classification on the [MNIST dataset](http://yann.lecun.com/exdb/mnist/).
## Architecture
| Layer | Output Shape | Parameters |
|-------|-------------|------------|
| Input | 1 x 28 x 28 | - |
| Conv2d(1, 32, 3) + BN + ReLU + MaxPool | 32 x 14 x 14 | 384 |
| Conv2d(32, 64, 3) + BN + MaxPool | 64 x 7 x 7 | 18,624 |
| Flatten + Linear(3136, 128) + ReLU | 128 | 401,536 |
| Dropout(0.25) + Linear(128, 10) | 10 | 1,290 |
| **Total** | | **~422K** |
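The Parameters column follows from standard layer arithmetic. The sketch below recomputes it in plain Python, assuming the 3x3 convolutions use `padding=1` (which is what turns 28 x 28 into 14 x 14 after a 2 x 2 max-pool) and counting only BatchNorm's learnable scale and shift. The helper names are illustrative, not taken from `model.py`.

```python
# Recompute per-layer parameter counts and the flattened feature size.
# Assumptions (not read from the repo): padding=1 on every 3x3 conv,
# each MaxPool halves the spatial size, and BatchNorm contributes
# 2 learnable parameters per channel (running stats are buffers).

def conv2d_params(in_ch: int, out_ch: int, k: int) -> int:
    """Weights (out_ch * in_ch * k * k) plus one bias per output channel."""
    return out_ch * in_ch * k * k + out_ch

def batchnorm_params(channels: int) -> int:
    """Learnable gamma and beta, one of each per channel."""
    return 2 * channels

def linear_params(in_f: int, out_f: int) -> int:
    """Weight matrix plus bias vector."""
    return in_f * out_f + out_f

def conv_pool_size(size: int, k: int = 3, padding: int = 1, pool: int = 2) -> int:
    """Spatial size after a stride-1 conv followed by a max-pool."""
    after_conv = size - k + 2 * padding + 1
    return after_conv // pool

block1 = conv2d_params(1, 32, 3) + batchnorm_params(32)
block2 = conv2d_params(32, 64, 3) + batchnorm_params(64)
side = conv_pool_size(conv_pool_size(28))  # 28 -> 14 -> 7
flat = 64 * side * side                    # features entering the first Linear
fc1 = linear_params(flat, 128)
fc2 = linear_params(128, 10)
total = block1 + block2 + fc1 + fc2

print(block1, block2, fc1, fc2, total)  # → 384 18624 401536 1290 421834
```

The fully connected layer dominates: roughly 95% of the weights sit in `Linear(3136, 128)`, which is typical for small CNNs that flatten a 7 x 7 feature map.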
## Features
- **CNN architecture** with batch normalisation and dropout
- **Data augmentation** (random rotation, translation) to improve generalisation
- **Train / Validation / Test split** (54K / 6K / 10K)
- **Early stopping** on validation loss with best-model checkpointing
- **Learning rate scheduling** (StepLR)
- **Full evaluation report** with per-class precision, recall, and F1
- **Visualisations** — training curves, confusion matrix, sample predictions
- **Single-image inference** script for custom images
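The 54K / 6K split carves a validation set out of MNIST's 60K training images, while the official 10K test set stays untouched. A minimal, reproducible way to do that is a seeded shuffle of indices, sketched below; `split_indices` and the seed value are hypothetical. In PyTorch the same effect is usually achieved with `torch.utils.data.random_split` and a seeded `torch.Generator`, though the repo's `dataset.py` may do it differently.

```python
import random

def split_indices(n_total: int, n_val: int, seed: int = 42) -> tuple[list[int], list[int]]:
    """Shuffle 0..n_total-1 with a fixed seed, then take the last n_val as validation."""
    indices = list(range(n_total))
    random.Random(seed).shuffle(indices)  # local RNG: does not touch global random state
    return indices[: n_total - n_val], indices[n_total - n_val :]

train_idx, val_idx = split_indices(60_000, 6_000)
print(len(train_idx), len(val_idx))  # → 54000 6000
```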
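Early stopping with best-model checkpointing comes down to tracking the lowest validation loss seen so far, saving the weights whenever it improves, and stopping after a run of epochs without improvement. The class below is a framework-agnostic sketch; `EarlyStopping`, `patience`, and `min_delta` are illustrative names and values, not the repo's configuration, and a plain dict stands in for `model.state_dict()`.

```python
import copy

class EarlyStopping:
    """Hypothetical helper: stop after `patience` epochs without a val-loss improvement."""

    def __init__(self, patience: int = 5, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.best_state = None
        self.bad_epochs = 0

    def step(self, val_loss: float, state: dict) -> bool:
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            # Snapshot the best weights (with PyTorch: a deep copy of model.state_dict()).
            self.best_state = copy.deepcopy(state)
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience

# Usage with made-up validation losses.
stopper = EarlyStopping(patience=2)
losses = [0.30, 0.20, 0.15, 0.16, 0.17, 0.14]
stopped_at = None
for epoch, loss in enumerate(losses, start=1):
    if stopper.step(loss, {"epoch": epoch}):
        stopped_at = epoch
        break
print(stopped_at, stopper.best_loss, stopper.best_state)
```

With `patience=2`, this loss sequence stops training at epoch 5 and keeps the state from epoch 3, the last improvement; the later 0.14 is never seen. In a real loop the saved state would typically be written to `outputs/best_model.pth` (for example with `torch.save`) and reloaded before the final test evaluation.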
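PyTorch's `StepLR` multiplies the learning rate by `gamma` every `step_size` epochs, so the rate at epoch e is lr0 * gamma ** (e // step_size). The function below reproduces that closed form. The `step_size=5` and `gamma=0.5` values are assumptions for illustration, since the README does not state the repo's settings; only the base rate of 0.001 comes from the documented `--lr` default.

```python
def step_lr(base_lr: float, epoch: int, step_size: int, gamma: float = 0.1) -> float:
    """Learning rate at a 0-based epoch under a StepLR-style schedule."""
    return base_lr * gamma ** (epoch // step_size)

# Hypothetical settings: halve the rate every 5 epochs over 15 epochs.
schedule = [step_lr(1e-3, e, step_size=5, gamma=0.5) for e in range(15)]
for epoch in (0, 5, 10):
    print(epoch, schedule[epoch])
```

In PyTorch itself the equivalent would be `torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)`, with `scheduler.step()` called once per epoch after the optimizer updates.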
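Per-class precision, recall, and F1 can all be read off the confusion matrix. Precision divides the diagonal entry by its column sum (everything predicted as that class), recall divides it by its row sum (everything that truly belongs to it), and F1 is their harmonic mean. The sketch below uses a made-up 3-class matrix; the repo may instead rely on a library such as scikit-learn's `classification_report`, which the README does not specify.

```python
def per_class_metrics(confusion: list[list[int]]) -> list[tuple[float, float, float]]:
    """Return (precision, recall, f1) per class.

    confusion[i][j] counts samples whose true label is i and predicted label is j.
    """
    n = len(confusion)
    results = []
    for c in range(n):
        tp = confusion[c][c]
        predicted_c = sum(confusion[r][c] for r in range(n))  # column sum
        actual_c = sum(confusion[c])                          # row sum
        precision = tp / predicted_c if predicted_c else 0.0
        recall = tp / actual_c if actual_c else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        results.append((precision, recall, f1))
    return results

# Toy 3-class example with made-up counts (not MNIST results).
cm = [
    [8, 2, 0],
    [1, 9, 0],
    [0, 0, 10],
]
metrics = per_class_metrics(cm)
for label, (p, r, f1) in enumerate(metrics):
    print(f"class {label}: precision={p:.3f} recall={r:.3f} f1={f1:.3f}")
```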
## Quick Start
```bash
# 1. Install dependencies
pip install -r requirements.txt
# 2. Train and evaluate
python main.py
# 3. Predict on a custom image
python predict.py path/to/digit.png
```
## CLI Options
```
python main.py [OPTIONS]

  --epochs N        Number of training epochs (default: 15)
  --batch-size N    Batch size (default: 128)
  --lr F            Learning rate (default: 0.001)
  --eval-only       Skip training, evaluate a saved checkpoint
```
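These options map directly onto Python's standard `argparse` module. The parser below is a hypothetical reconstruction of the documented flags, not the repo's actual `main.py`; note that `argparse` turns `--batch-size` and `--eval-only` into the attributes `batch_size` and `eval_only`.

```python
import argparse

def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the documented main.py options."""
    parser = argparse.ArgumentParser(description="Train and evaluate the MNIST CNN.")
    parser.add_argument("--epochs", type=int, default=15, help="Number of training epochs")
    parser.add_argument("--batch-size", type=int, default=128, help="Batch size")
    parser.add_argument("--lr", type=float, default=0.001, help="Learning rate")
    parser.add_argument("--eval-only", action="store_true",
                        help="Skip training, evaluate a saved checkpoint")
    return parser

# Parse an explicit argument list instead of sys.argv, for demonstration.
args = build_parser().parse_args(["--epochs", "5", "--eval-only"])
print(args.epochs, args.batch_size, args.lr, args.eval_only)  # → 5 128 0.001 True
```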
## Project Structure
```
MnistPY/
├── main.py            # Entry point: train, evaluate, visualise
├── config.py          # All hyperparameters and paths
├── model.py           # CNN architecture (MNISTNet)
├── dataset.py         # Data loading, transforms, splits
├── train.py           # Training loop with validation
├── evaluate.py        # Metrics and classification report
├── predict.py         # Single-image inference CLI
├── visualize.py       # Plotting utilities
├── requirements.txt   # Python dependencies
└── outputs/           # Generated after training
    ├── best_model.pth
    ├── training_curves.png
    ├── confusion_matrix.png
    └── sample_predictions.png
```
## Expected Results
With default hyperparameters (~15 epochs), the model reaches **~99.2% test accuracy**.