https://github.com/felixsoderstrom/cifar
- Host: GitHub
- URL: https://github.com/felixsoderstrom/cifar
- Owner: FelixSoderstrom
- Created: 2025-05-20T09:51:45.000Z (11 months ago)
- Default Branch: main
- Last Pushed: 2025-05-21T10:13:41.000Z (11 months ago)
- Last Synced: 2025-06-02T08:19:38.577Z (11 months ago)
- Language: Python
- Size: 26.4 KB
- Stars: 0
- Watchers: 1
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
README
# CIFAR-10 Image Classification
This project implements deep learning models for image classification on the CIFAR-10 dataset. It provides a complete pipeline for training, evaluating, and visualizing the performance of two CNN architectures: a custom network trained from scratch and a transfer-learning model built on ResNet50.
## Project Overview
The CIFAR-10 dataset consists of 60,000 32x32 color images in 10 classes, with 6,000 images per class, split into 50,000 training images and 10,000 test images. This project offers:
- Training of custom CNN architectures from scratch
- Transfer learning using pre-trained ResNet50
- Comprehensive evaluation metrics and visualizations
- Modular codebase for easy experimentation
## Project Structure
```
CIFAR/
│
├── main.py                      # Entry point for training and evaluation
├── README.md                    # This file
│
└── src/
    ├── data_processing/         # Data loading and augmentation
    │   ├── augment.py           # Data augmentation functions
    │   └── utils.py             # Data utilities
    │
    ├── networks/                # Model architectures
    │   ├── classic_network.py   # Custom CNN architecture
    │   ├── transfer_network.py  # Transfer learning with ResNet50
    │   └── utils.py             # Network utilities
    │
    ├── training/                # Training functionality
    │   ├── trainer.py           # Training loop implementation
    │   └── utils.py             # Training utilities
    │
    └── evaluation/              # Evaluation functionality
        ├── evaluate.py          # Model evaluation
        ├── utils.py             # Evaluation utilities
        └── visualize.py         # Visualization functions
```
## Usage
### Training a Model
To train the custom CNN (`classic`) from scratch:
```bash
python main.py --model classic --epochs 30 --batch_size 128 --lr 0.001 --gpu
```
To train using transfer learning with ResNet50:
```bash
python main.py --model transfer --epochs 20 --batch_size 64 --lr 0.0001 --gpu
```
### Command Line Arguments
- `--model`: Model architecture to use (`classic` or `transfer`)
- `--epochs`: Number of training epochs (default: 30)
- `--batch_size`: Batch size for training (default: 128)
- `--lr`: Learning rate (default: 0.001)
- `--weight_decay`: Weight decay for optimizer (default: 1e-4)
- `--seed`: Random seed (default: 42)
- `--gpu`: Use GPU if available (flag)
- `--evaluate_only`: Only run evaluation on a trained model (flag)
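The arguments above map naturally onto the standard-library `argparse` module. The sketch below is a hypothetical reconstruction, not the actual contents of `main.py`: the defaults are taken from the list above, while the `parse_args` helper name and the choice to make `--model` required (the README gives no default for it) are assumptions.

```python
import argparse
from typing import Optional, Sequence


def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Hypothetical parser mirroring the documented CLI; not taken from main.py."""
    parser = argparse.ArgumentParser(description="CIFAR-10 image classification")
    # No default is documented for --model, so this sketch requires it (assumption).
    parser.add_argument("--model", choices=["classic", "transfer"], required=True,
                        help="Model architecture to use")
    parser.add_argument("--epochs", type=int, default=30,
                        help="Number of training epochs")
    parser.add_argument("--batch_size", type=int, default=128,
                        help="Batch size for training")
    parser.add_argument("--lr", type=float, default=0.001, help="Learning rate")
    parser.add_argument("--weight_decay", type=float, default=1e-4,
                        help="Weight decay for the optimizer")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    # Boolean switches: present on the command line -> True, absent -> False.
    parser.add_argument("--gpu", action="store_true",
                        help="Use GPU if available")
    parser.add_argument("--evaluate_only", action="store_true",
                        help="Only run evaluation on a trained model")
    return parser.parse_args(argv)
```

For example, `parse_args(["--model", "transfer", "--lr", "0.0001"])` yields a namespace with `lr=0.0001` while `epochs`, `batch_size`, and the other options keep their documented defaults. Passing `None` makes `argparse` read `sys.argv`, which is what a real entry point would do.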
### Evaluation Only
To evaluate a trained model without retraining:
```bash
python main.py --model classic --evaluate_only --gpu
```
## Features
### Data Augmentation
The project implements several data augmentation techniques:
- Random cropping
- Random horizontal flips
- Random rotation
- Color jitter
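To make the listed augmentations concrete, here is a NumPy-only sketch of three of them on a single HxWxC float image with values in [0, 1]. It is an illustration under stated assumptions, not the code in `augment.py`: the padding of 4 pixels, the flip probability of 0.5, and the restriction of color jitter to brightness only are all choices made for this example, and random rotation is left out. In a PyTorch pipeline the same steps are usually expressed with `torchvision.transforms` (`RandomCrop`, `RandomHorizontalFlip`, `RandomRotation`, `ColorJitter`).

```python
import numpy as np


def random_crop(img: np.ndarray, rng: np.random.Generator, padding: int = 4) -> np.ndarray:
    """Zero-pad an HxWxC image on each side, then cut out a random HxW window."""
    h, w, _ = img.shape
    padded = np.pad(img, ((padding, padding), (padding, padding), (0, 0)), mode="constant")
    top = rng.integers(0, 2 * padding + 1)   # upper bound is exclusive
    left = rng.integers(0, 2 * padding + 1)
    return padded[top:top + h, left:left + w, :]


def random_hflip(img: np.ndarray, rng: np.random.Generator, p: float = 0.5) -> np.ndarray:
    """Mirror the image left-to-right with probability p."""
    return img[:, ::-1, :] if rng.random() < p else img


def brightness_jitter(img: np.ndarray, rng: np.random.Generator,
                      max_delta: float = 0.2) -> np.ndarray:
    """Brightness-only color jitter: scale by a factor in [1 - max_delta, 1 + max_delta]."""
    factor = rng.uniform(1.0 - max_delta, 1.0 + max_delta)
    return np.clip(img * factor, 0.0, 1.0)   # keep values in the valid [0, 1] range


def augment(img: np.ndarray, rng: np.random.Generator) -> np.ndarray:
    """Apply crop, flip, and brightness jitter in sequence (hypothetical pipeline)."""
    img = random_crop(img, rng)
    img = random_hflip(img, rng)
    return brightness_jitter(img, rng)
```

Passing an explicit `np.random.Generator` keeps the augmentations reproducible, which fits the project's `--seed` option.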
### Model Architectures
1. **ClassicCNN**: A custom CNN architecture with:
- 4 convolutional blocks with an increasing number of filters
- Batch normalization
- Max pooling
- Dropout for regularization
- Fully connected layers
2. **TransferResNet50**: A transfer learning approach using:
- Pre-trained ResNet50 as feature extractor
- Custom classification head for CIFAR-10
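A quick size calculation shows how small CIFAR-10 feature maps become in both models. The README does not give layer widths, so the ClassicCNN channel counts below (32, 64, 128, 256) and the assumption that every block ends in a 2x2 max pool are hypothetical. The ResNet50 figures follow from the standard architecture, which halves the resolution five times (an overall stride of 32) and ends with 2048 channels. The README does not say whether images are resized before entering ResNet50; at the native 32x32 the final map is 1x1, and at 224x224 it would be 7x7.

```python
def spatial_size(input_size: int, num_halvings: int) -> int:
    """Side length after repeated stride-2 downsampling (e.g. 2x2 max pooling)."""
    size = input_size
    for _ in range(num_halvings):
        size //= 2
    return size


# Hypothetical ClassicCNN widths: the README only says the filter count grows per block.
CLASSIC_CHANNELS = [32, 64, 128, 256]

# Assuming each of the 4 blocks halves the resolution: 32 -> 16 -> 8 -> 4 -> 2.
classic_side = spatial_size(32, len(CLASSIC_CHANNELS))
# Number of inputs to the first fully connected layer after flattening.
classic_flat = CLASSIC_CHANNELS[-1] * classic_side ** 2

# ResNet50 has five stride-2 stages (overall stride 32) and ends with 2048 channels.
resnet_side = spatial_size(32, 5)   # a native 32x32 CIFAR image shrinks to 1x1
RESNET50_FEATURES = 2048            # length of the pooled feature vector
NUM_CLASSES = 10

# Parameter count if the custom head were a single linear layer (an assumption).
head_params = RESNET50_FEATURES * NUM_CLASSES + NUM_CLASSES
```

Under these assumptions the ClassicCNN flattens to 1,024 features, and a single linear head on ResNet50 would have 20,490 parameters. Both numbers change if the real layer widths or head differ.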
### Evaluation Metrics
- Accuracy (overall and per-class)
- Precision, recall, and F1 score
- Confusion matrix
- Visualization of feature embeddings (t-SNE and PCA)
- Visualization of misclassified samples
- Training and validation curves
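All of the numeric metrics above can be derived from a single confusion matrix. The NumPy sketch below shows that relationship. It is illustrative only: the project lists scikit-learn as a dependency and may use `sklearn.metrics.confusion_matrix` and `sklearn.metrics.precision_recall_fscore_support` instead, and the function names here are hypothetical. Per-class accuracy is taken to mean the fraction of each true class that was predicted correctly, which equals per-class recall.

```python
import numpy as np


def confusion_matrix(y_true, y_pred, num_classes: int = 10) -> np.ndarray:
    """Count predictions: rows are true classes, columns are predicted classes."""
    cm = np.zeros((num_classes, num_classes), dtype=np.int64)
    # np.add.at accumulates correctly even when the same (row, col) pair repeats.
    np.add.at(cm, (np.asarray(y_true), np.asarray(y_pred)), 1)
    return cm


def per_class_metrics(cm: np.ndarray):
    """Return overall accuracy and per-class precision, recall, and F1."""
    tp = np.diag(cm).astype(float)
    predicted = cm.sum(axis=0)   # how often each class was predicted
    actual = cm.sum(axis=1)      # how many samples truly belong to each class
    # Classes that were never predicted (or never present) get 0 instead of NaN.
    precision = np.divide(tp, predicted, out=np.zeros_like(tp), where=predicted > 0)
    recall = np.divide(tp, actual, out=np.zeros_like(tp), where=actual > 0)
    denom = precision + recall
    f1 = np.divide(2 * precision * recall, denom, out=np.zeros_like(tp), where=denom > 0)
    accuracy = tp.sum() / cm.sum()
    return accuracy, precision, recall, f1
```

Macro-averaged scores are then simply `precision.mean()`, `recall.mean()`, and `f1.mean()`, and the matrix itself is what a confusion-matrix plot displays.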
## Output
The results are saved in an `output/session_X` directory, where `X` is the session number. Each session directory contains:
- `checkpoints/`: Model weights for each epoch and the best model
- `plots/`: Visualization plots (confusion matrices, embeddings, etc.)
- `test_results_*.txt`: Detailed evaluation metrics
- `training_summary.txt`: Summary of the training process
- `stats_*.json`: Training statistics for plotting
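One way to allocate the next `output/session_X` directory is to scan the existing session folders and add one to the highest number found. The helper below is a hypothetical sketch of that idea, not code from this repository. It assumes sessions are numbered from 1 and creates only the two subfolders listed above, since the `.txt` and `.json` files sit directly in the session directory.

```python
from pathlib import Path


def next_session_dir(output_root="output") -> Path:
    """Create and return <output_root>/session_<N+1>, N being the highest existing number."""
    root = Path(output_root)
    root.mkdir(parents=True, exist_ok=True)
    numbers = [
        int(p.name.split("_", 1)[1])
        for p in root.glob("session_*")
        # Skip stray files and names like "session_old" that lack a numeric suffix.
        if p.is_dir() and p.name.split("_", 1)[1].isdigit()
    ]
    session = root / f"session_{max(numbers, default=0) + 1}"
    # Subfolders for model weights and plots; mkdir(parents=True) also creates `session`.
    for sub in ("checkpoints", "plots"):
        (session / sub).mkdir(parents=True)
    return session
```

Comparing the parsed integers, rather than sorting folder names as strings, keeps `session_10` after `session_9`.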
## Requirements
- Python 3.6+
- PyTorch
- torchvision
- numpy
- matplotlib
- scikit-learn
- tqdm
- pytorch-lightning
## License
[MIT License](LICENSE)