https://github.com/mpieper19/machine-learning-model-comparision-with-cifar10
School project for our "Intro to Machine Learning" course, done in collaboration with two other students. This project compares the accuracy and performance of standard machine learning models (e.g. KNN, LightGBM) against our custom-built CNN on the CIFAR-10 dataset.
- Host: GitHub
- URL: https://github.com/mpieper19/machine-learning-model-comparision-with-cifar10
- Owner: mpieper19
- License: mit
- Created: 2024-12-26T13:25:46.000Z (over 1 year ago)
- Default Branch: main
- Last Pushed: 2025-02-07T10:34:29.000Z (over 1 year ago)
- Last Synced: 2025-09-11T22:48:31.190Z (8 months ago)
- Topics: cifar10, cnn-classification, convolutional-neural-networks, image-classification, keras, keras-tensorflow, machine-learning, scikitlearn-machine-learning, tensorflow
- Language: Python
- Homepage:
- Size: 1.93 MB
- Stars: 1
- Watchers: 1
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
# Machine-Learning-Model-Comparison-with-CIFAR10 (Uni Project)
Intro to Machine Learning WBAI056
## Table of Contents
- [About the Project](#about-the-project)
- [Built With](#built-with)
- [Usage](#usage)
- [Results](#results)
## About the Project
This project was conducted with two other students. We compare the accuracy and effectiveness of a custom-built Convolutional Neural Network (CNN) against several standard machine learning models on the CIFAR-10 dataset. The following machine learning models are used for comparison:
- K-Nearest Neighbors (KNN)
- Light Gradient Boosting Machine (LightGBM)
- CatBoost
- Random Forest Classifier
## Built With
- Python
- TensorFlow
- Keras
- scikit-learn
- NumPy
- Pandas
- Matplotlib
- Seaborn
- CatBoost
- LightGBM
Use `pip install -r requirements.txt` to install required dependencies.
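As an optional sanity check after installing, the short script below reports which of the packages listed above cannot be imported. It is a sketch that is not part of the repository; the module names (for example `sklearn` for scikit-learn) are the usual import names and are assumed to match what `requirements.txt` installs.

```python
import importlib.util

# Usual import names for the packages listed under "Built With".
# Assumption: requirements.txt installs these; adjust the list if it differs.
REQUIRED_MODULES = [
    "tensorflow",
    "keras",
    "sklearn",  # scikit-learn
    "numpy",
    "pandas",
    "matplotlib",
    "seaborn",
    "catboost",
    "lightgbm",
]


def missing_modules(names):
    """Return the top-level module names that cannot be found on the import path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(REQUIRED_MODULES)
    if missing:
        print("Missing packages:", ", ".join(missing))
    else:
        print("All required packages are installed.")
```

`find_spec` only locates each module without importing it, so the check stays fast even for large packages such as TensorFlow.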
## Usage
To test and run each model, configure the `main.py` file according to the desired model:
### For the CNN Model:
- Use the `train_and_evaluate_CNN()` function to train and evaluate the CNN model.
- Ensure the following:
  - One-hot encoding for labels is enabled (`one_hot=True`).
  - Image flattening is disabled (`flatten=False`), as the CNN processes structured image data.
- Results, including training/validation loss and accuracy curves, ROC/AUC curves, the model object, and classification reports, will be saved in the `results` directory.
Run the CNN model with the following code:
```python
accuracy, results = train_and_evaluate_CNN()
```
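With `one_hot=True`, each integer class label (0 to 9) becomes a length-10 indicator vector, and with `flatten=False` the images keep their 32x32x3 shape. The sketch below illustrates both on made-up arrays. The helper `to_one_hot` is hypothetical and not the project's code; Keras provides the same conversion as `keras.utils.to_categorical`.

```python
import numpy as np

NUM_CLASSES = 10  # CIFAR-10 has 10 classes


def to_one_hot(labels, num_classes=NUM_CLASSES):
    """Convert integer labels of shape (N,) or (N, 1) into one-hot rows of shape (N, num_classes)."""
    flat = np.asarray(labels, dtype=int).reshape(-1)
    # Row k of the identity matrix is the one-hot vector for class k
    return np.eye(num_classes, dtype="float32")[flat]


# Made-up data in CIFAR-10 layout: 4 RGB images of 32x32 pixels
images = np.zeros((4, 32, 32, 3), dtype="uint8")
labels = np.array([[3], [0], [9], [3]])  # Keras' cifar10 loader returns labels with shape (N, 1)

y = to_one_hot(labels)
print(images.shape)  # (4, 32, 32, 3): left unflattened for the CNN
print(y.shape)       # (4, 10)
print(y[0])          # 1.0 at index 3, zeros elsewhere
```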
### For Classic Machine Learning Models:
- Use the `train_and_evaluate_model(model_code, model_name)` function for classic models. `model_code` must be one of the keys of the `_model_map` dictionary in `models/__init__.py`, and `model_name` is a descriptive label such as `"KNN"` or `"RandomForest"`.
```python
# models/__init__.py: maps a short model code to its model class
_model_map = {
    'cnn': CNNModel,
    'knn': KNNModel,
    'lgbm': LGBM,
    'cat': CatBoost,
    'forest': RFCModel,
}
```
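Selecting a model by its code is a plain dictionary lookup. The sketch below reproduces the `_model_map` pattern with two placeholder classes; the stub classes and the `build_model` helper are hypothetical and only show how a code such as `"forest"` resolves to a class, and how an unknown code could be reported together with the valid keys.

```python
class KNNModel:
    """Placeholder for the project's KNN model class (hypothetical stub)."""


class RFCModel:
    """Placeholder for the project's Random Forest model class (hypothetical stub)."""


# Same shape as _model_map: short model code -> model class
_model_map = {
    "knn": KNNModel,
    "forest": RFCModel,
}


def build_model(model_code):
    """Instantiate the class registered under model_code (hypothetical helper)."""
    try:
        model_cls = _model_map[model_code]
    except KeyError:
        valid = ", ".join(sorted(_model_map))
        raise ValueError(f"Unknown model code {model_code!r}; expected one of: {valid}") from None
    return model_cls()


model = build_model("forest")
print(type(model).__name__)  # RFCModel
```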
- Ensure the following:
  - Set `flatten=True` to prepare the images as 1D vectors, as required by classic ML models.
- Results, including confusion matrices, ROC/AUC curves, the model object, and classification reports, will be saved in the `results` directory.
Run the KNN model with the following code:
```python
accuracy, results = train_and_evaluate_model("knn", "KNN")
```
For example, to train the Random Forest model:
```python
accuracy, results = train_and_evaluate_model("forest", "RandomForest")
```
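For the classic models, `flatten=True` turns each 32x32x3 image into a single vector of 32 * 32 * 3 = 3072 pixel values, because estimators such as KNN or Random Forest take one feature row per sample. A minimal NumPy sketch on made-up data, not the project's preprocessing code:

```python
import numpy as np

# Made-up batch in CIFAR-10 layout: 5 RGB images of 32x32 pixels
images = (np.arange(5 * 32 * 32 * 3) % 256).astype(np.uint8).reshape(5, 32, 32, 3)

# One row per image: 32 * 32 * 3 = 3072 features
features = images.reshape(len(images), -1)

print(images.shape)    # (5, 32, 32, 3)
print(features.shape)  # (5, 3072)
```

Flattening keeps every pixel value but discards which pixels are neighbors, which is one reason the CNN is run on the unflattened images.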
## Results
The classification reports for the models are saved as CSV files in the `results/reports/` directory. For example:
- [CNN Classification Report](results/reports/CNN_classification_report.csv)
- [KNN Classification Report](results/reports/KNN_classification_report.csv)
- [Random Forest Classifier Classification Report](results/reports/randomforest_classification_report.csv)
- [LightGBM Classification Report](results/reports/LightGBM_classification_report.csv)
- [CatBoost Classification Report](results/reports/CatBoost_classification_report.csv)
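Each report lists per-class precision, recall, F1-score and support. scikit-learn's `classification_report(y_true, y_pred, output_dict=True)` returns these values as a dictionary, and it is an assumption that the project converts that output to CSV. The plain-Python sketch below spells out the per-class definitions and writes them in a similar CSV layout; the function names and column order are hypothetical.

```python
import csv
import io

FIELDS = ["class", "precision", "recall", "f1-score", "support"]


def per_class_report(y_true, y_pred, labels):
    """Per-class precision, recall, F1-score and support (hypothetical helper)."""
    pairs = list(zip(y_true, y_pred))
    rows = []
    for label in labels:
        tp = sum(1 for t, p in pairs if t == label and p == label)
        predicted = sum(1 for _, p in pairs if p == label)  # true + false positives
        support = sum(1 for t, _ in pairs if t == label)    # true positives + false negatives
        precision = tp / predicted if predicted else 0.0
        recall = tp / support if support else 0.0
        denom = precision + recall
        f1 = 2 * precision * recall / denom if denom else 0.0
        rows.append({"class": label, "precision": precision, "recall": recall,
                     "f1-score": f1, "support": support})
    return rows


def write_report_csv(rows, stream):
    """Write the report rows as CSV with a header line."""
    writer = csv.DictWriter(stream, fieldnames=FIELDS)
    writer.writeheader()
    writer.writerows(rows)


# Made-up labels for three CIFAR-10 classes
y_true = ["cat", "cat", "dog", "ship"]
y_pred = ["cat", "dog", "dog", "ship"]
rows = per_class_report(y_true, y_pred, labels=["cat", "dog", "ship"])

buffer = io.StringIO()  # swap for open("report.csv", "w", newline="") to write a file
write_report_csv(rows, buffer)
print(buffer.getvalue())
```

The scikit-learn report additionally includes overall accuracy and macro/weighted average rows, which this sketch leaves out.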
### CNN Results:
CNN Accuracy: **70%**
CNN Training and Validation Loss Plot:

CNN Training and Validation Accuracy Plot:

CNN ROC/AUC Curve:

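For a 10-class problem, ROC curves are usually drawn one-vs-rest: for each class, the model's predicted probability for that class is scored against a binary "this class vs. all others" target. That the project's plots were produced this way is an assumption. The plain-Python sketch below computes one AUC per class using the pairwise (Mann-Whitney) form of AUC, the probability that a random positive sample is scored higher than a random negative one, with ties counted as half. scikit-learn's `roc_auc_score` with `multi_class="ovr"` is the library route.

```python
def binary_auc(y_true, scores):
    """ROC AUC as the probability that a random positive outscores a random negative."""
    pos = [s for t, s in zip(y_true, scores) if t == 1]
    neg = [s for t, s in zip(y_true, scores) if t == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative sample")
    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5  # a tie counts as half a win
    return wins / (len(pos) * len(neg))


def one_vs_rest_auc(labels, probs, num_classes):
    """One AUC per class k: class k is the positive class, all other classes are negative."""
    return [
        binary_auc([1 if y == k else 0 for y in labels], [row[k] for row in probs])
        for k in range(num_classes)
    ]


# Made-up predictions for 4 samples and 3 classes (each row sums to 1)
labels = [0, 1, 2, 0]
probs = [
    [0.7, 0.2, 0.1],
    [0.1, 0.8, 0.1],
    [0.2, 0.3, 0.5],
    [0.2, 0.6, 0.2],
]
print(one_vs_rest_auc(labels, probs, num_classes=3))  # [0.875, 1.0, 1.0]
```

The pairwise loop is quadratic in the number of samples, so it only serves to make the definition explicit; a sort-based or library implementation is the practical choice for the 10,000 CIFAR-10 test images.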
### KNN Results:
KNN Accuracy: **32%**
KNN Confusion Matrix:

KNN ROC/AUC Curve:

### Random Forest Classifier Results:
Random Forest Classifier Accuracy: **46%**
Random Forest Classifier Confusion Matrix:

Random Forest Classifier ROC/AUC Curve:

### LightGBM Results:
LightGBM Accuracy: **53%**
LightGBM Loss Plot:

LightGBM Confusion Matrix:

LightGBM ROC/AUC Curve:

### CatBoost Results:
CatBoost Accuracy: **58%**
CatBoost Loss Plot:

CatBoost Confusion Matrix:

CatBoost ROC/AUC Curve:

## License
This project is licensed under the MIT License.