
https://github.com/debugger404/cifar-pytorch_model

Train Basic Model on CIFAR10 Dataset - 🎨🖥️ Utilizes CIFAR-10 dataset with 60000 32x32 color images in 10 classes. Demonstrates loading using torchvision and training with pretrained models like ResNet18, AlexNet, VGG16, DenseNet161, and Inception. Notebook available for experimentation.

cifar10 cifar10-classification dataset image-classification python pytorch training


# Train a Basic Model on the CIFAR10 Dataset





## Contents
- [Introduction](#introduction)
- [Prerequisites](#prerequisites)
- [Training](#training)

## Introduction
The `CIFAR-10` dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images.

Below are 6 random images from the dataset with their respective labels:
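To draw a similar sample yourself, the sketch below picks random indices and maps each integer label to its class name. It assumes the dataset yields `(image, label)` pairs, as torchvision's `CIFAR10` does; `sample_labeled` and `CIFAR10_CLASSES` are illustrative names, not part of this repository.

```python
import random

# CIFAR-10 label names in index order (the same list torchvision exposes
# as `train_dataset.classes`)
CIFAR10_CLASSES = ['airplane', 'automobile', 'bird', 'cat', 'deer',
                   'dog', 'frog', 'horse', 'ship', 'truck']


def sample_labeled(dataset, k=6, seed=None):
    """Return k distinct random (index, class name) pairs from a dataset
    whose items are (image, integer label) tuples. Hypothetical helper."""
    rng = random.Random(seed)
    indices = rng.sample(range(len(dataset)), k)  # sampling without replacement
    return [(i, CIFAR10_CLASSES[dataset[i][1]]) for i in indices]
```

With the `train_dataset` from the loading example in this section, `sample_labeled(train_dataset)` returns six `(index, name)` pairs; plotting the images themselves is left out of this sketch.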

The Python package `torchvision` provides a ready-made `CIFAR10` dataset class (`torchvision.datasets.CIFAR10`) and image transforms (`torchvision.transforms`). The dataset can then be wrapped in `torch.utils.data.DataLoader` for batching and shuffling.

Below is an example of how to load the `CIFAR10` dataset using `torchvision`:

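```python
import torch
import torchvision
import torchvision.transforms as transforms

## load data CIFAR10
# ToTensor() turns each PIL image into a 3x32x32 float tensor in [0, 1];
# without a transform, DataLoader cannot batch the raw PIL images.
train_dataset = torchvision.datasets.CIFAR10(root='./train_data', train=True, download=True,
                                             transform=transforms.ToTensor())
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)
```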
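`torchvision.transforms.Normalize(mean, std)` is a common addition to the transform pipeline. As a sketch, assuming the loader yields `(images, labels)` batches of float tensors (for example a `CIFAR10` dataset built with `transform=transforms.ToTensor()`), the per-channel statistics can be accumulated in a single pass. `channel_mean_std` is a hypothetical helper, not part of torchvision or this repository.

```python
import torch


def channel_mean_std(loader):
    """Per-channel mean and population std over every pixel yielded by `loader`.

    Hypothetical helper. Assumes each batch is (images, labels) with images
    shaped (N, C, H, W), e.g. a CIFAR10 dataset created with transform=ToTensor().
    """
    n_pixels = 0
    total = total_sq = None
    for images, _ in loader:
        images = images.double()                     # float64 limits rounding drift over many pixels
        n, c, h, w = images.shape
        batch_sum = images.sum(dim=(0, 2, 3))        # one value per channel
        batch_sq = (images ** 2).sum(dim=(0, 2, 3))
        total = batch_sum if total is None else total + batch_sum
        total_sq = batch_sq if total_sq is None else total_sq + batch_sq
        n_pixels += n * h * w
    mean = total / n_pixels
    # Var[x] = E[x^2] - E[x]^2; clamp guards against tiny negative rounding error
    std = (total_sq / n_pixels - mean ** 2).clamp(min=0).sqrt()
    return mean, std
```

Compute these on the training split only, then pass them on as `transforms.Normalize(mean.tolist(), std.tolist())`, placed after `ToTensor()` inside a `transforms.Compose([...])`.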

## Prerequisites
- Python >= 3.6
- PyTorch >= 1.4
- Other required libraries are listed in `requirements.txt`
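A quick way to confirm the two version minimums is sketched below. `major_minor` and `check_prerequisites` are hypothetical helpers; the sketch compares integer tuples because a plain string comparison would rank `'1.10'` below `'1.4'`. It does not check the libraries in the requirements file.

```python
import sys


def major_minor(version):
    """Turn a version string such as '1.13.1+cu117' into (1, 13)."""
    release = version.split('+')[0]          # drop a local build tag like '+cu117'
    major, minor = release.split('.')[:2]
    return int(major), int(minor)


def check_prerequisites(torch_version):
    """True if this interpreter is Python >= 3.6 and torch_version is >= 1.4."""
    return sys.version_info[:2] >= (3, 6) and major_minor(torch_version) >= (1, 4)
```

A typical call is `import torch; check_prerequisites(torch.__version__)`.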

## Training
I used a pretrained `resnet18` for model training. You can swap in any other pretrained torchvision model that suits your problem:
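```python
import torchvision.models as models

# pretrained=True loads ImageNet weights (downloaded on first use).
# torchvision >= 0.13 deprecates it and suggests weights="DEFAULT" instead.
resnet18 = models.resnet18(pretrained=True)
alexnet = models.alexnet(pretrained=True)
vgg16 = models.vgg16(pretrained=True)
densenet = models.densenet161(pretrained=True)
inception = models.inception_v3(pretrained=True)
```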
There are two ways to train the model:
1. Notebook: download it and experiment with it.
2. Python scripts:
```bash
# Start training with:
python main.py

# You can also pass the training options explicitly:
python main.py --lr=0.01 --epoch 20 --model_path './cifar_model.pth'

# Start inference with:
python prediction.py --model_path './cifar_model.pth'
```
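The training scripts themselves are not reproduced in this README. For orientation only, here is a minimal, hypothetical sketch of one training epoch and of batch prediction in plain PyTorch. It assumes a model that returns a single logits tensor (ResNet-18 does; Inception v3 in training mode also returns auxiliary logits) and uses cross-entropy loss, which may differ from what `main.py` actually does.

```python
import torch
import torch.nn as nn


def train_one_epoch(model, loader, optimizer, device="cpu"):
    """One pass over `loader`; returns the sample-weighted mean cross-entropy loss."""
    criterion = nn.CrossEntropyLoss()
    model.train()
    total_loss, n_seen = 0.0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * images.size(0)  # loss is a mean over the batch
        n_seen += images.size(0)
    return total_loss / n_seen


@torch.no_grad()
def predict(model, images, device="cpu"):
    """Predicted class index for each image in a batch."""
    model.eval()
    return model(images.to(device)).argmax(dim=1).cpu()
```

For example, `optimizer = torch.optim.SGD(model.parameters(), lr=0.01)` followed by 20 calls to `train_one_epoch(model, train_dataloader, optimizer)` roughly corresponds to `--lr=0.01 --epoch 20`. `torch.save(model.state_dict(), './cifar_model.pth')` stores the weights, and `model.load_state_dict(torch.load('./cifar_model.pth'))` restores them before calling `predict`.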

# Give a :star: to this Repository!