Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/debugger404/cifar-pytorch_model
Train Basic Model on CIFAR10 Dataset - 🎨🖥️ Utilizes CIFAR-10 dataset with 60000 32x32 color images in 10 classes. Demonstrates loading using torchvision and training with pretrained models like ResNet18, AlexNet, VGG16, DenseNet161, and Inception. Notebook available for experimentation.
Last synced: 4 days ago
- Host: GitHub
- URL: https://github.com/debugger404/cifar-pytorch_model
- Owner: deBUGger404
- Created: 2021-05-08T16:06:27.000Z (over 3 years ago)
- Default Branch: main
- Last Pushed: 2021-05-09T05:39:14.000Z (over 3 years ago)
- Last Synced: 2024-11-05T09:48:38.259Z (about 2 months ago)
- Topics: cifar10, cifar10-classification, dataset, image-classification, python, pytorch, training
- Language: Jupyter Notebook
- Homepage:
- Size: 163 KB
- Stars: 1
- Watchers: 1
- Forks: 1
- Open Issues: 0
Metadata Files:
- Readme: README.md
README
# Train Basic Model on CIFAR10-Dataset
## Contents
- [Introduction](#introduction)
- [Prerequisites](#prerequisites)
- [Training](#training)

## Introduction
The `CIFAR-10` dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images. Below are six random images with their respective labels:

*[Figure: six random CIFAR-10 images with their labels]*
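As a quick sanity check, the snippet below multiplies out the per-class counts from the paragraph above. The class names are listed in the label order used by `torchvision.datasets.CIFAR10` (index 0 is `airplane`, index 9 is `truck`). The 5000/1000 per-class train/test split is the standard CIFAR-10 split. The paragraph implies it but does not state it.

```python
# CIFAR-10 class names in label-index order (0 to 9)
CIFAR10_CLASSES = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]

IMAGES_PER_CLASS = 6000
TRAIN_PER_CLASS = 5000  # the training split holds 5000 images of each class
TEST_PER_CLASS = 1000   # the test split holds 1000 images of each class

total = len(CIFAR10_CLASSES) * IMAGES_PER_CLASS
train_total = len(CIFAR10_CLASSES) * TRAIN_PER_CLASS
test_total = len(CIFAR10_CLASSES) * TEST_PER_CLASS

print(total, train_total, test_total)  # 60000 50000 10000
```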
The Python package `torchvision` provides a ready-made dataset class for `CIFAR10` (`torchvision.datasets.CIFAR10`) and image transforms (`torchvision.transforms`). The dataset can then be batched and shuffled with `torch.utils.data.DataLoader`.
Below is an example of how to load the `CIFAR10` dataset using `torchvision`:
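```python
import torch
import torchvision
import torchvision.transforms as transforms

# Convert PIL images to tensors so the DataLoader can stack them into batches
transform = transforms.ToTensor()

## load data CIFAR10
train_dataset = torchvision.datasets.CIFAR10(root='./train_data', train=True, download=True, transform=transform)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
```

## Prerequisites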
- Python >= 3.6
- PyTorch >= 1.4
- Other required libraries are listed in `requirements.txt`

## Training
I used a pretrained `resnet18` for model training. You can use any other pretrained model that suits your problem:
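```python
import torchvision.models as models

# pretrained=True downloads ImageNet weights (newer torchvision prefers the `weights=` argument)
resnet18 = models.resnet18(pretrained=True)
alexnet = models.alexnet(pretrained=True)
vgg16 = models.vgg16(pretrained=True)
densenet = models.densenet161(pretrained=True)
inception = models.inception_v3(pretrained=True)  # expects 299x299 inputs
```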
There are two ways to train the PyTorch model:
1. Notebook: download it and experiment with it.
2. Python scripts:
```bash
# Start training with:
python main.py
# You can also pass training options explicitly:
python main.py --lr=0.01 --epoch 20 --model_path './cifar_model.pth'
# Start inference with:
python3.6 prediction.py --model_path './cifar_model.pth'
```

# Give a :star: to this Repository!