https://github.com/muhd-umer/torch-classification
PyTorch-based image classification with super-resolution support
- Host: GitHub
- URL: https://github.com/muhd-umer/torch-classification
- Owner: muhd-umer
- License: mit
- Created: 2023-12-13T16:03:52.000Z (over 2 years ago)
- Default Branch: main
- Last Pushed: 2024-01-02T14:59:28.000Z (over 2 years ago)
- Last Synced: 2025-04-06T21:49:28.607Z (about 1 year ago)
- Topics: cnn, computer-vision, deep-learning, image-classification, pytorch
- Language: Python
- Homepage:
- Size: 5.36 MB
- Stars: 1
- Watchers: 1
- Forks: 2
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
# Torch Classification
[License: MIT](https://opensource.org/licenses/MIT) [PyTorch](https://pytorch.org/) [CIFAR-100](https://www.cs.toronto.edu/~kriz/cifar.html)
Torch Classification is a PyTorch-based image classification project that implements the [EfficientNet V2](https://arxiv.org/abs/2104.00298) family of models. It covers both training from scratch and transfer learning with pre-trained weights on the CIFAR-100 dataset. It also examines how image super-resolution with BSRGAN and SwinIR affects classification on the same dataset. The project was undertaken as part of a Machine Learning course at NUST, with an emphasis on practical applications of deep learning.
## Installation
To get started with this project, follow the steps below:
- Clone the repository to your local machine using the following command:
```fish
git clone https://github.com/muhd-umer/torch-classification.git
```
- It is recommended to create a new virtual environment so that updates/downgrades of packages do not break other projects. To create a new virtual environment, run the following command:
```fish
conda env create -f environment.yml
```
- Alternatively, you can use the `mamba` package manager (faster than conda) to create a new virtual environment:
```fish
wget -O miniforge.sh \
"https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-$(uname)-$(uname -m).sh"
bash miniforge.sh -b -p "${HOME}/conda"
source "${HOME}/conda/etc/profile.d/conda.sh"
# For mamba support also run the following command
source "${HOME}/conda/etc/profile.d/mamba.sh"
conda activate
mamba env create -f environment.yml
```
- Activate the newly created environment:
```fish
conda activate torch-classification
```
- Install the PyTorch Ecosystem:
```fish
# pip will take care of necessary CUDA packages
pip3 install torch torchvision torchaudio
# additional packages (already included in environment.yml)
pip3 install einops python-box timm torchinfo \
lightning rich wandb rawpy
```
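After installation, a quick import check can catch a missing package before a training run fails. The sketch below uses only the standard library. The module names are assumptions based on each package's usual import name (for example, `python-box` is imported as `box`), so adjust the list if your environment differs.

```python
import importlib.util

# Import names for the packages installed above (assumed; python-box
# installs the module `box`, the others match their package names).
REQUIRED_MODULES = [
    "torch", "torchvision", "torchaudio", "einops", "box",
    "timm", "torchinfo", "lightning", "rich", "wandb", "rawpy",
]


def missing_modules(names):
    """Return the subset of `names` that cannot be found for import."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    missing = missing_modules(REQUIRED_MODULES)
    if missing:
        print("Missing modules:", ", ".join(missing))
    else:
        print("All required modules are importable.")
```

Run it inside the activated `torch-classification` environment; an empty result means every listed module was found.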
## Dataset
The CIFAR-100 dataset is used for training and testing the model. The dataset can be downloaded from [here](https://www.cs.toronto.edu/~kriz/cifar.html).
Or, you can use the following commands to download the dataset:
```fish
# download as python pickle
cd data
curl -O https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz
tar -xvzf cifar-100-python.tar.gz
# return to the repository root so the output path below resolves correctly
cd ..
# or download as PNG images in an ImageNet-style folder layout
pip3 install cifar2png
cifar2png cifar100 data/cifar100
```
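To inspect the pickle download from Python, the split files can be read with the standard `pickle` module. This is a minimal sketch, not the project's own loader: `load_split` and `pixel_rgb` are hypothetical helper names, and the path assumes the archive was extracted inside `data/`. The byte layout (1024 red, 1024 green, then 1024 blue values per image) follows the description on the CIFAR dataset page.

```python
import pickle
from pathlib import Path


def load_split(path):
    """Unpickle one CIFAR-100 split file and return it with string keys."""
    with open(path, "rb") as f:
        raw = pickle.load(f, encoding="bytes")
    return {k.decode() if isinstance(k, bytes) else k: v for k, v in raw.items()}


def pixel_rgb(row, x, y):
    """Return the (R, G, B) values at column x, row y of one flat image row.

    Each image is 3072 values: 1024 red, then 1024 green, then 1024 blue,
    with every channel stored row by row as a 32x32 grid.
    """
    offset = y * 32 + x
    return row[offset], row[1024 + offset], row[2048 + offset]


if __name__ == "__main__":
    # Assumed location after extracting the archive inside data/.
    train_file = Path("data/cifar-100-python/train")
    if train_file.exists():
        train = load_split(train_file)
        print(len(train["fine_labels"]), "training images")
        print("top-left pixel of first image:", pixel_rgb(train["data"][0], 0, 0))
```

Unpickling the real files requires NumPy to be installed, since the image data is stored as a NumPy array.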
We also offer super-resolution variants of the CIFAR-100 dataset, in which the images are upscaled to `128x128` resolution using [BSRGAN 4x](https://github.com/cszn/BSRGAN) and [SwinIR](https://github.com/JingyunLiang/SwinIR). You can download these datasets from the [Weights & Data](https://github.com/muhd-umer/torch-classification/releases/) section. Or, you can use the following commands to download them:
```fish
wget -O data/bsrgan_4x_cifar100.zip \
"https://github.com/muhd-umer/torch-classification/releases/download/v0.0.1/bsrgan_4x_cifar100.zip"
# unzip the dataset
unzip -q data/bsrgan_4x_cifar100.zip -d data/
# or
wget -O data/swinir_4x_cifar100.zip \
"https://github.com/muhd-umer/torch-classification/releases/download/v0.0.1/swinir_4x_cifar100.zip"
# unzip the dataset
unzip -q data/swinir_4x_cifar100.zip -d data/
```
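After unzipping, it can be worth confirming that the extraction is complete. The sketch below counts images per class sub-directory. It assumes the archives unpack into an ImageFolder-style layout (one folder per class inside each split), which this README does not state, and the example path is hypothetical. If the archives mirror the original CIFAR-100 splits, the training split should report 100 classes with 500 images each.

```python
from pathlib import Path

IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg"}


def count_images_per_class(split_dir):
    """Map each class sub-directory name to the number of images in it."""
    counts = {}
    for class_dir in sorted(Path(split_dir).iterdir()):
        if class_dir.is_dir():
            counts[class_dir.name] = sum(
                1 for p in class_dir.iterdir() if p.suffix.lower() in IMAGE_SUFFIXES
            )
    return counts


if __name__ == "__main__":
    # Hypothetical path; check the actual folder name after unzipping.
    split_dir = Path("data/bsrgan_4x_cifar100/train")
    if split_dir.is_dir():
        counts = count_images_per_class(split_dir)
        print(len(counts), "classes,", sum(counts.values()), "images")
```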
## Usage
To train the model, run one of the following commands:
```fish
# train the model from scratch using default config
python3 train.py
# train the model using overrides (all flags are optional; replace the
# upper-case placeholders with your own values):
#   --mode MODE                      train or finetune
#   --data-dir DATA_DIR              directory containing data
#   --model-dir MODEL_DIR            directory to save model
#   --batch-size BATCH_SIZE          batch size
#   --dataset-type DATASET_TYPE      default or imagefolder
#   --num-workers NUM_WORKERS        number of workers
#   --num-epochs NUM_EPOCHS          number of epochs
#   --lr LR                          learning rate
#   --rich-progress                  use rich progress bar
#   --accelerator ACCELERATOR        type of accelerator
#   --devices DEVICES                number of devices
#   --weights WEIGHTS                path to weights file
#   --resume                         resume training from checkpoint
#   --test-only                      test the model on test set
#   --logger-backend LOGGER_BACKEND  wandb or tensorboard
python3 train.py --mode MODE --data-dir DATA_DIR --batch-size BATCH_SIZE
```
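For readers who want to see how a flag set like this maps onto Python, here is a minimal `argparse` sketch that declares the flags listed above. It is an illustration only, not the contents of `train.py`: the value types and the `choices` lists are assumptions inferred from the flag descriptions, and defaults are deliberately left unset because the README does not give them.

```python
import argparse


def build_parser():
    """Declare the documented flags; types and choices are assumptions."""
    p = argparse.ArgumentParser(description="Sketch of the train.py command line")
    p.add_argument("--mode", choices=["train", "finetune"])
    p.add_argument("--data-dir")
    p.add_argument("--model-dir")
    p.add_argument("--batch-size", type=int)
    p.add_argument("--dataset-type", choices=["default", "imagefolder"])
    p.add_argument("--num-workers", type=int)
    p.add_argument("--num-epochs", type=int)
    p.add_argument("--lr", type=float)
    p.add_argument("--rich-progress", action="store_true")
    p.add_argument("--accelerator")
    p.add_argument("--devices")
    p.add_argument("--weights")
    p.add_argument("--resume", action="store_true")
    p.add_argument("--test-only", action="store_true")
    p.add_argument("--logger-backend", choices=["wandb", "tensorboard"])
    return p


if __name__ == "__main__":
    # Parse an example argument list instead of sys.argv for demonstration.
    args = build_parser().parse_args(["--mode", "finetune", "--lr", "1e-3", "--test-only"])
    print(args.mode, args.lr, args.test_only)
```

Note that `argparse` converts dashes to underscores, so `--batch-size` is read as `args.batch_size`.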
To evaluate the models, download the appropriate weights from the [Weights & Data](https://github.com/muhd-umer/torch-classification/releases/) section and place them in the `weights/` directory. Then, run one of the following commands:
```fish
bash run.sh
# or
python3 train.py --weights WEIGHTS --test-only
```
## Project Structure
The project is structured as follows:
```fish
torch-classification
├── data/ # data directory
├── models/ # model directory
├── resources/ # resources directory
├── utils/ # utility directory
├── LICENSE # license file
├── README.md # readme file
├── environment.yml # conda environment file
├── upscale.py # upscaling script
└── train.py # training script
```
## Contributing ❤️
Pull requests are welcome. For major changes, please open an issue first to discuss what you would like to change.