Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/not-ml/ml-3
A PyTorch-based Convolutional Neural Network (CNN) for image classification using the CIFAR-10 dataset, featuring advanced architecture, data augmentation, GPU support, and dynamic learning rate scheduling.
- Host: GitHub
- URL: https://github.com/not-ml/ml-3
- Owner: Not-ML
- License: mit
- Created: 2024-11-22T11:19:04.000Z (3 months ago)
- Default Branch: main
- Last Pushed: 2024-11-22T11:38:38.000Z (3 months ago)
- Last Synced: 2025-01-23T07:47:43.350Z (28 days ago)
- Topics: ai, cifar10, cnn, cuda, gpu, image-classification, machine-learning, modeltraining, python, pytorch, torchvision
- Language: Python
- Homepage:
- Size: 15.6 KB
- Stars: 0
- Watchers: 1
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
README
# data:image/s3,"s3://crabby-images/8e95f/8e95fedf9e7a9cb9d022832bec4b4f37e9a47fdf" alt="pytorch-logo-dark"
# Advanced CNN for CIFAR-10 Classification
This project implements an advanced Convolutional Neural Network (CNN) in PyTorch to classify images from the CIFAR-10 dataset into 10 categories: `airplane`, `automobile`, `bird`, `cat`, `deer`, `dog`, `frog`, `horse`, `ship`, and `truck`.
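The ten labels above are listed in the same order as the integer targets of `torchvision.datasets.CIFAR10` (index 0 is `airplane`, index 9 is `truck`). Below is a minimal sketch of turning a score vector into a class name. It assumes the model emits one score (logit) per class in that order. `scores_to_label` is a hypothetical helper, not a function from this repository.

```python
# Class names in the order used by torchvision's CIFAR10 targets (0-9).
CIFAR10_CLASSES = (
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
)


def scores_to_label(scores):
    """Return the class name with the highest score.

    Hypothetical helper: `scores` is assumed to be a sequence of 10
    per-class scores (e.g. logits) in CIFAR10_CLASSES order.
    """
    if len(scores) != len(CIFAR10_CLASSES):
        raise ValueError(f"expected {len(CIFAR10_CLASSES)} scores, got {len(scores)}")
    # argmax: index of the largest score (the first one wins on ties)
    best_index = max(range(len(scores)), key=lambda i: scores[i])
    return CIFAR10_CLASSES[best_index]


print(scores_to_label([0.1, 2.3, -1.0, 0.0, 0.5, 0.2, -0.3, 1.1, 0.9, 2.0]))  # automobile
```

Taking the argmax of raw logits gives the same class as taking it after a softmax, because softmax preserves the ordering of the scores.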
## Features
- **Benchmark Dataset**: Uses CIFAR-10, a dataset of 60,000 32x32 color images split into 50,000 training and 10,000 test images.
- **Customizable CNN Model**: Includes multiple convolutional, batch normalization, pooling, and fully connected layers for accurate classification.
- **GPU Support**: Automatically utilizes CUDA if available for faster computation.
- **Data Augmentation**: Augments training images with random cropping and flipping; inputs are also normalized.
- **Learning Rate Scheduler**: Reduces the learning rate over the course of training so that later epochs make smaller, finer updates.

## Setup and Requirements
### Prerequisites
- Python 3.x
- Required packages: `torch`, `torchvision`, `tqdm`, `Pillow`

Install dependencies with:
```bash
pip install torch torchvision tqdm Pillow
```

### Clone the Repository
```bash
git clone https://github.com/not-ml/ml-3.git
cd ml-3
```

## Training the Model
Run the script to train the model:
```bash
python advanced_cnn_cifar10.py
```
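The Features list notes that the learning rate is reduced during training, but the README does not name the scheduler. The sketch below assumes a step-decay policy equivalent to PyTorch's `torch.optim.lr_scheduler.StepLR`, which multiplies the rate by `gamma` every `step_size` epochs. Under that assumption, the rate in effect for each epoch can be worked out in plain Python. The values `step_size=10` and `gamma=0.1` are illustrative assumptions, not the script's settings.

```python
def step_decay_lr(base_lr, epoch, step_size=10, gamma=0.1):
    """Learning rate for a 0-based `epoch` under step decay.

    Mirrors StepLR's closed form: base_lr * gamma ** (epoch // step_size).
    The step_size and gamma defaults are assumptions for illustration.
    """
    if step_size <= 0:
        raise ValueError("step_size must be positive")
    return base_lr * gamma ** (epoch // step_size)


# Print the schedule for a hypothetical 30-epoch run starting at 0.1.
for epoch in (0, 9, 10, 19, 20, 29):
    print(epoch, step_decay_lr(0.1, epoch))
```

With `torch`, the same policy would be created with `StepLR(optimizer, step_size=10, gamma=0.1)` and advanced by calling `scheduler.step()` once per epoch.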
- Adjust `num_epochs`, `batch_size`, or `learning_rate` in the script for your requirements.
- Trained models are saved in the `checkpoints/` directory.

## Inference
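An image passed to the trained model should be preprocessed the same way as the training data. That means it should be resized to CIFAR-10's 32x32 RGB format and normalized per channel with the same mean and standard deviation used in training. The README does not list the statistics the script uses. The values below are commonly quoted CIFAR-10 training-set statistics and are an assumption. The sketch applies the per-channel step, `x_norm = (x / 255 - mean) / std`, to one 8-bit RGB pixel.

```python
# Assumed per-channel statistics (R, G, B) on the [0, 1] scale; these are
# commonly quoted CIFAR-10 training-set values, not read from this repo.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)


def normalize_pixel(rgb, mean=CIFAR10_MEAN, std=CIFAR10_STD):
    """Normalize one 8-bit (R, G, B) pixel the way torchvision's
    ToTensor() followed by Normalize(mean, std) would."""
    if len(rgb) != 3:
        raise ValueError("expected an (R, G, B) triple")
    # Scale 0-255 to 0-1, then standardize each channel.
    return tuple((value / 255.0 - m) / s for value, m, s in zip(rgb, mean, std))


print(normalize_pixel((125, 123, 114)))
print(normalize_pixel((255, 255, 255)))
```

In a torchvision pipeline, the same steps would typically be `transforms.Resize((32, 32))`, `transforms.ToTensor()` and `transforms.Normalize(mean, std)`. Grayscale images or images with an alpha channel can be converted first with Pillow's `Image.convert("RGB")`.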
### Predict Single Image
```python
from predict import predict_image

model_path = 'checkpoints/advanced_cnn.pth'
image_path = 'path/to/your/image.jpg'

prediction = predict_image(image_path, model_path)
print(f'Predicted class: {prediction}')
```

### Predict Batch of Images
```python
from predict import predict_batch

model_path = 'checkpoints/advanced_cnn.pth'

batch_predictions = predict_batch('path/to/dataset', model_path)
print(batch_predictions)
```

## Results
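Test accuracy is the fraction of test images whose predicted class equals the true label. Per-class accuracy is the same ratio restricted to one class, which shows where errors concentrate. Below is a minimal sketch over plain lists of integer label indices (0-9). `accuracy_report` is a hypothetical helper, not part of this repository.

```python
from collections import Counter


def accuracy_report(predicted, actual, num_classes=10):
    """Return (overall_accuracy, per_class_accuracy) for label indices.

    per_class_accuracy maps each class index that occurs in `actual`
    to the fraction of its examples that were predicted correctly.
    """
    if len(predicted) != len(actual) or not actual:
        raise ValueError("need two non-empty label lists of equal length")
    totals = Counter(actual)
    # Count correct predictions, keyed by the true class.
    hits = Counter(a for p, a in zip(predicted, actual) if p == a)
    overall = sum(hits.values()) / len(actual)
    per_class = {c: hits[c] / totals[c] for c in range(num_classes) if totals[c]}
    return overall, per_class


overall, per_class = accuracy_report([3, 5, 3, 0, 1], [3, 3, 3, 0, 9])
print(f"overall: {overall:.0%}")  # overall: 60%
print(per_class)                  # {0: 1.0, 3: 0.6666666666666666, 9: 0.0}
```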
- Achieves 85-90% accuracy on the CIFAR-10 test set after 30 epochs.
- Checkpoints and accuracy logs are written during training.

## Additional Notes
- Modify the architecture or hyperparameters to experiment with different configurations.
- A GPU is highly recommended for faster training.

For more details, see the code files: `advanced_cnn_cifar10.py` (training) and `predict.py` (inference).
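When modifying the architecture, the input size of the first fully connected layer must match the flattened output of the last convolution/pooling stage. Along each spatial axis, PyTorch documents the output size of `nn.Conv2d` and `nn.MaxPool2d` (with `ceil_mode=False`) as `floor((in + 2*padding - dilation*(kernel - 1) - 1) / stride) + 1`. The sketch below applies that formula to a 32x32 input. The layer stack is a hypothetical example (three 3x3 convolutions with padding 1, each followed by 2x2 max pooling, ending in 256 channels), and the script's actual layers may differ.

```python
def conv_out(size, kernel, stride=1, padding=0, dilation=1):
    """Spatial output size of Conv2d/MaxPool2d (ceil_mode=False) along one axis."""
    out = (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1
    if out <= 0:
        raise ValueError(f"layer shrinks input of size {size} to {out}")
    return out


def flattened_features(input_size, stages, out_channels):
    """Apply (conv_kernel, conv_padding, pool_kernel) stages; return C*H*W.

    Assumes square inputs, stride-1 convolutions and pooling whose stride
    equals its kernel size (PyTorch's MaxPool2d default).
    """
    size = input_size
    for conv_kernel, conv_padding, pool_kernel in stages:
        size = conv_out(size, conv_kernel, stride=1, padding=conv_padding)
        size = conv_out(size, pool_kernel, stride=pool_kernel)
    return out_channels * size * size


# Hypothetical stack: three 3x3 convs (padding 1), each followed by 2x2 max pooling.
print(flattened_features(32, [(3, 1, 2)] * 3, out_channels=256))  # 4096
```

Here the spatial size goes 32 → 16 → 8 → 4. The first `nn.Linear` after flattening would therefore need `in_features=4096`.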