https://github.com/canreader/mnistpy

This is my first AI model, which uses TensorFlow/Keras.

# MNIST Digit Classifier

A professional PyTorch implementation of a Convolutional Neural Network (CNN) for
handwritten digit classification on the [MNIST dataset](http://yann.lecun.com/exdb/mnist/).

## Architecture

| Layer | Output Shape | Parameters |
|-------|-------------|------------|
| Input | 1 x 28 x 28 | - |
| Conv2d(1, 32, 3) + BN + ReLU + MaxPool | 32 x 14 x 14 | 384 |
| Conv2d(32, 64, 3) + BN + MaxPool | 64 x 7 x 7 | 18,624 |
| Flatten + Linear(3136, 128) + ReLU | 128 | 401,536 |
| Dropout(0.25) + Linear(128, 10) | 10 | 1,290 |
| **Total** | | **~422K** |
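
The parameter counts in the table can be checked by hand. The sketch below is one way to do that arithmetic in plain Python. It assumes, since the README does not state it, that each convolution uses `padding=1` and a bias term, that each max-pool halves the spatial size, and that batch normalisation contributes only its learnable scale and shift. The helper names are illustrative and are not taken from `model.py`.

```python
def conv2d_params(in_ch, out_ch, kernel):
    """Kernel weights plus one bias per output channel."""
    return out_ch * in_ch * kernel * kernel + out_ch


def batchnorm_params(channels):
    """Learnable scale and shift per channel (running statistics excluded)."""
    return 2 * channels


def linear_params(in_features, out_features):
    """Weight matrix plus one bias per output feature."""
    return in_features * out_features + out_features


def conv_out(size, kernel, padding=1, stride=1):
    """Spatial size after a convolution (padding=1 is an assumption)."""
    return (size + 2 * padding - kernel) // stride + 1


# Spatial size: 28 -> conv 28 -> pool 14 -> conv 14 -> pool 7
side = 28
side = conv_out(side, 3) // 2
side = conv_out(side, 3) // 2
flat = 64 * side * side                                    # 3136

block1 = conv2d_params(1, 32, 3) + batchnorm_params(32)    # 384
block2 = conv2d_params(32, 64, 3) + batchnorm_params(64)   # 18624
fc1 = linear_params(flat, 128)                             # 401536
fc2 = linear_params(128, 10)                               # 1290
total = block1 + block2 + fc1 + fc2                        # 421834

print(f"total trainable parameters: {total:,}")
```

Under these assumptions the total comes to 421,834, which rounds to the ~422K shown above.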

## Features

- **CNN architecture** with batch normalisation and dropout
- **Data augmentation** (random rotation, translation) to improve generalisation
- **Train / Validation / Test split** (54K / 6K / 10K)
- **Early stopping** on validation loss with best-model checkpointing
- **Learning rate scheduling** (StepLR)
- **Full evaluation report** with per-class precision, recall, and F1
- **Visualisations** — training curves, confusion matrix, sample predictions
- **Single-image inference** script for custom images
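
To illustrate how early stopping and step-wise learning rate decay interact, here is a minimal, framework-free sketch. It is not the code in `train.py`: the `EarlyStopping` class, the `step_lr` helper, and the values `patience=3`, `step_size=5`, and `gamma=0.5` are all assumptions chosen for the example. The decay rule is the same formula that PyTorch's `StepLR` applies, `base_lr * gamma ** (epoch // step_size)`.

```python
class EarlyStopping:
    """Track validation loss and signal when it stops improving."""

    def __init__(self, patience=3, min_delta=0.0):
        self.patience = patience      # epochs to wait without improvement
        self.min_delta = min_delta    # minimum decrease that counts as better
        self.best_loss = float("inf")
        self.bad_epochs = 0

    def step(self, val_loss):
        """Record one epoch; return True if it is the new best."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.bad_epochs = 0
            return True
        self.bad_epochs += 1
        return False

    @property
    def should_stop(self):
        return self.bad_epochs >= self.patience


def step_lr(base_lr, epoch, step_size=5, gamma=0.5):
    """Multiply the learning rate by gamma once every step_size epochs."""
    return base_lr * gamma ** (epoch // step_size)


# Simulated validation losses stand in for a real training run.
val_losses = [0.30, 0.20, 0.21, 0.22, 0.19, 0.25, 0.26, 0.27, 0.28]
stopper = EarlyStopping(patience=3)
best_epoch = None
stopped_at = None

for epoch, loss in enumerate(val_losses):
    lr = step_lr(0.001, epoch)
    if stopper.step(loss):
        best_epoch = epoch    # a real loop would save the checkpoint here
    if stopper.should_stop:
        stopped_at = epoch
        break

print(f"best epoch: {best_epoch}, stopped at epoch: {stopped_at}")
```

With the simulated losses, the best epoch is 4 and training halts at epoch 7, after three epochs without improvement. In a real run the checkpoint saved at the best epoch is the one to evaluate on the test set.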

## Quick Start

```bash
# 1. Install dependencies
pip install -r requirements.txt

# 2. Train and evaluate
python main.py

# 3. Predict on a custom image
python predict.py path/to/digit.png
```

## CLI Options

```
python main.py [OPTIONS]

  --epochs N       Number of training epochs (default: 15)
  --batch-size N   Batch size (default: 128)
  --lr F           Learning rate (default: 0.001)
  --eval-only      Skip training, evaluate a saved checkpoint
```
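
For readers who want to see how such options are typically declared, the following is a sketch using the standard-library `argparse` module. It mirrors the flags and defaults listed above, but the `build_parser` function and the help strings are assumptions for illustration. The actual parsing code in `main.py` may differ.

```python
import argparse


def build_parser():
    """Build a parser for the options documented above (illustrative only)."""
    parser = argparse.ArgumentParser(
        prog="main.py",
        description="Train and evaluate the MNIST classifier.",
    )
    parser.add_argument("--epochs", type=int, default=15, metavar="N",
                        help="number of training epochs")
    parser.add_argument("--batch-size", type=int, default=128, metavar="N",
                        help="batch size")
    parser.add_argument("--lr", type=float, default=0.001, metavar="F",
                        help="learning rate")
    parser.add_argument("--eval-only", action="store_true",
                        help="skip training, evaluate a saved checkpoint")
    return parser


# Parse an explicit argument list so the example runs without a shell.
args = build_parser().parse_args(["--epochs", "5", "--eval-only"])
print(args.epochs, args.batch_size, args.lr, args.eval_only)
```

Note that `argparse` exposes `--batch-size` as `args.batch_size` and `--eval-only` as `args.eval_only`, with hyphens converted to underscores.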

## Project Structure

```
MnistPY/
├── main.py              # Entry point: train, evaluate, visualise
├── config.py            # All hyperparameters and paths
├── model.py             # CNN architecture (MNISTNet)
├── dataset.py           # Data loading, transforms, splits
├── train.py             # Training loop with validation
├── evaluate.py          # Metrics and classification report
├── predict.py           # Single-image inference CLI
├── visualize.py         # Plotting utilities
├── requirements.txt     # Python dependencies
└── outputs/             # Generated after training
    ├── best_model.pth
    ├── training_curves.png
    ├── confusion_matrix.png
    └── sample_predictions.png
```

## Expected Results

With the default hyperparameters (up to 15 epochs, batch size 128, learning rate 0.001), the model reaches about **99.2% test accuracy**.