https://github.com/horrible-dong/qtclassification
A lightweight and extensible toolbox for image classification and MORE
- Host: GitHub
- URL: https://github.com/horrible-dong/qtclassification
- Owner: horrible-dong
- License: apache-2.0
- Created: 2023-03-18T09:59:40.000Z (over 3 years ago)
- Default Branch: main
- Last Pushed: 2026-02-22T13:23:50.000Z (4 months ago)
- Last Synced: 2026-02-22T18:16:37.688Z (4 months ago)
- Topics: amp, cifar10, cifar100, configs, deep-learning, extensible, image-classfication, imagenet, lightweight, machine-learning, pytorch, resnet, template, toolbox, vision-transformer
- Language: Python
- Homepage:
- Size: 393 KB
- Stars: 19
- Watchers: 2
- Forks: 1
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
QTClassification
========
**A lightweight and extensible toolbox for image classification and MORE**
> Author: QIU Tian
> Affiliation: Zhejiang University
> Installation | Documentation | Dataset Zoo | Model Zoo
>
> English | [Simplified Chinese](README_zh-CN.md)
## Installation
This project was developed with `python 3.8` and `pytorch 1.13.1+cu117`.
1. Create your conda environment.
```bash
conda create -n qtcls python=3.8 -y
```
2. Enter your conda environment.
```bash
conda activate qtcls
```
3. Install PyTorch.
```bash
pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117
```
Alternatively, refer to the [PyTorch previous versions](https://pytorch.org/get-started/previous-versions/) page to install a newer or older version.
Note that PyTorch ≥ 1.13 requires Python ≥ 3.7.2.
4. Install necessary dependencies.
```bash
pip install -r requirements.txt
```
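The version note in step 3 can be checked before training. Below is a minimal sketch of a hypothetical helper (not part of the toolbox) that reads the installed `torch` version through the standard `importlib.metadata` module and applies only the rule stated above (PyTorch ≥ 1.13 needs Python ≥ 3.7.2); the function names are illustrative.

```python
import sys
from importlib import metadata
from typing import Optional, Sequence, Tuple


def parse_version(text: str) -> Tuple[int, ...]:
    """Turn '1.13.1+cu117' into (1, 13, 1); the local '+cu117' tag is ignored."""
    core = text.split("+", 1)[0]
    parts = []
    for piece in core.split("."):
        # keep only the leading digits of each component, e.g. '0rc1' -> 0
        digits = ""
        for ch in piece:
            if not ch.isdigit():
                break
            digits += ch
        parts.append(int(digits) if digits else 0)
    return tuple(parts)


def python_ok_for_torch(torch_version: str,
                        python_version: Sequence[int] = sys.version_info[:3]) -> bool:
    """Apply the README rule: PyTorch >= 1.13 requires Python >= 3.7.2."""
    if parse_version(torch_version)[:2] >= (1, 13):
        return tuple(python_version) >= (3, 7, 2)
    return True


def installed_torch_version() -> Optional[str]:
    """Return the installed torch version string, or None if torch is missing."""
    try:
        return metadata.version("torch")
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    version = installed_torch_version()
    if version is None:
        print("torch is not installed")
    else:
        ok = python_ok_for_torch(version)
        print(f"torch {version}: {'OK' if ok else 'Python is too old'}")
```

The check only encodes the rule quoted in this README; the official PyTorch compatibility matrix remains the authoritative source.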
## Getting Started
For a quick start, run the following commands:
**Training**
```bash
# single-gpu
CUDA_VISIBLE_DEVICES=0 \
python main.py \
--data_root ./data \
--dataset cifar10 \
--model resnet50 \
--batch_size 256 \
--lr 1e-4 \
--epochs 12 \
--output_dir ./runs/__tmp__
# multi-gpu (requires pytorch>=1.9.0)
OMP_NUM_THREADS=1 CUDA_VISIBLE_DEVICES=0,1 \
torchrun --nproc_per_node=2 main.py \
--data_root ./data \
--dataset cifar10 \
--model resnet50 \
--batch_size 256 \
--lr 1e-4 \
--epochs 12 \
--output_dir ./runs/__tmp__
# multi-gpu (for any pytorch version, but with a "deprecated" warning)
OMP_NUM_THREADS=1 CUDA_VISIBLE_DEVICES=0,1 \
python -m torch.distributed.launch --nproc_per_node=2 --use_env main.py \
--data_root ./data \
--dataset cifar10 \
--model resnet50 \
--batch_size 256 \
--lr 1e-4 \
--epochs 12 \
--output_dir ./runs/__tmp__
```
The `cifar10` dataset and the `resnet50` pretrained weights are downloaded automatically, so make sure the network is
accessible. The `cifar10` dataset is saved to `./data`, and the `resnet50` pretrained weights are saved to
`~/.cache/torch/hub/checkpoints`.
During training, the config file, checkpoints, logs, and other outputs will be stored in `./runs/__tmp__`.
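The multi-GPU commands start one process per GPU. `torchrun` describes each process through the environment variables `RANK`, `LOCAL_RANK` and `WORLD_SIZE`, and `torch.distributed.launch --use_env` likewise passes `LOCAL_RANK` through the environment rather than as a `--local_rank` argument. A minimal sketch of how a script might read these variables and fall back to a single process follows; `DistInfo` and `read_dist_env` are hypothetical names, not the toolbox's implementation.

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional


@dataclass
class DistInfo:
    rank: int        # global index of this process across all nodes
    local_rank: int  # index of this process on the current node (usually its GPU id)
    world_size: int  # total number of processes

    @property
    def distributed(self) -> bool:
        return self.world_size > 1


def read_dist_env(env: Optional[Mapping[str, str]] = None) -> DistInfo:
    """Read launcher-provided variables; default to one process if they are absent."""
    env = os.environ if env is None else env
    if "RANK" in env and "WORLD_SIZE" in env:
        return DistInfo(
            rank=int(env["RANK"]),
            local_rank=int(env.get("LOCAL_RANK", 0)),
            world_size=int(env["WORLD_SIZE"]),
        )
    return DistInfo(rank=0, local_rank=0, world_size=1)


if __name__ == "__main__":
    info = read_dist_env()
    print(info, "distributed" if info.distributed else "single process")
```

In a real training script, each process would typically select `cuda:{local_rank}` and call `torch.distributed.init_process_group` when `distributed` is true.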
**Evaluation**
```bash
# single-gpu
CUDA_VISIBLE_DEVICES=0 \
python main.py \
--data_root ./data \
--dataset cifar10 \
--model resnet50 \
--batch_size 256 \
--resume ./runs/__tmp__/checkpoint.pth \
--eval
# multi-gpu (requires pytorch>=1.9.0)
OMP_NUM_THREADS=1 CUDA_VISIBLE_DEVICES=0,1 \
torchrun --nproc_per_node=2 main.py \
--data_root ./data \
--dataset cifar10 \
--model resnet50 \
--batch_size 256 \
--resume ./runs/__tmp__/checkpoint.pth \
--eval
# multi-gpu (for any pytorch version, but with a "deprecated" warning)
OMP_NUM_THREADS=1 CUDA_VISIBLE_DEVICES=0,1 \
python -m torch.distributed.launch --nproc_per_node=2 --use_env main.py \
--data_root ./data \
--dataset cifar10 \
--model resnet50 \
--batch_size 256 \
--resume ./runs/__tmp__/checkpoint.pth \
--eval
```
### Using a config file (Recommended)
You can also write the arguments into a config file (.py) and use `--config` / `-c` to import it.
See [configs](configs).
**Training**
```bash
# full command
python main.py --config /path/to/config.py
# short command
python main.py -c /path/to/config.py
# example
python main.py -c configs/_demo_.py
```
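A config file is an ordinary Python file. The hypothetical `my_config.py` below assumes that each top-level variable is named after a command-line argument of `main.py`; its values simply mirror the quick-start training command and are not the contents of `configs/_demo_.py`.

```python
# my_config.py (hypothetical): each top-level name is assumed to match a
# command-line argument of main.py, e.g. --data_root, --dataset, --model.
data_root = './data'
dataset = 'cifar10'
model = 'resnet50'
batch_size = 256
lr = 1e-4
epochs = 12
output_dir = './runs/__tmp__'
```

It would then be loaded with `python main.py -c my_config.py`.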
**Evaluation**
```bash
# full command
python main.py --config /path/to/config.py --resume /path/to/checkpoint.pth --eval
# short command
python main.py -c /path/to/config.py -r /path/to/checkpoint.pth --eval
# example
python main.py -c configs/_demo_.py -r ./runs/cifar10/vit_tiny_patch4_32/checkpoint.pth --eval
```
Arguments in the config file override, or are merged into, the command-line arguments `args` pre-defined in [`main.py`](main.py).
**Other examples**
```bash
python main.py -c configs/_demo_.py -co # clear the output dir first
python main.py -c configs/_demo_.py --batch_size 100 --print_freq 200 --note bs100
python main.py -c configs/_demo_.py --save_interval 5555 # do not save
python main.py -c configs/_demo_.py --dataset food --dummy # use fake data
python main.py -c configs/_demo_.py -d cifar100 -b 400 --note cifar100-bs400
torchrun --nproc_per_node=2 main.py -c configs/_demo_.py # multi-gpu
```
If an argument is set both in the config file and on the command line after `--config xxx` / `-c xxx`, the command-line value takes precedence.
For more details and advanced usage of config files,
please refer to ["How to write and import your configs"](configs/README.md).
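The precedence described above (built-in defaults, then the config file, then explicitly typed command-line arguments) can be modeled with the standard `argparse` and `runpy` modules. This is a simplified sketch, not the toolbox's code: the argument subset and `resolve_args` are hypothetical, and "merge" is modeled only as adding config names that have no command-line counterpart.

```python
import argparse
import runpy
from typing import List

# Hypothetical subset of main.py's arguments, using defaults from the table below.
ARG_SPECS = [
    (("--config", "-c"), {"type": str, "default": None}),
    (("--dataset", "-d"), {"type": str, "default": None}),
    (("--batch_size", "-b"), {"type": int, "default": 256}),
    (("--lr",), {"type": float, "default": 1e-4}),
]


def build_parser(keep_defaults: bool) -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser()
    for flags, options in ARG_SPECS:
        opts = dict(options)
        if not keep_defaults:
            # SUPPRESS leaves unset flags out of the namespace entirely,
            # so only arguments actually typed on the command line remain.
            opts["default"] = argparse.SUPPRESS
        parser.add_argument(*flags, **opts)
    return parser


def resolve_args(argv: List[str]) -> argparse.Namespace:
    # 1) defaults, already updated with whatever was passed on the command line
    args = build_parser(keep_defaults=True).parse_args(argv)
    # 2) only the arguments the user typed explicitly
    explicit = vars(build_parser(keep_defaults=False).parse_args(argv))
    if args.config:
        # 3) config values override defaults and may add new names
        for name, value in runpy.run_path(args.config).items():
            if not name.startswith("_"):
                setattr(args, name, value)
        # 4) explicit command-line values win over the config
        for name, value in explicit.items():
            setattr(args, name, value)
    return args
```

For example, with a config containing `batch_size = 100`, the call `resolve_args(["-c", path, "-b", "400"])` yields a batch size of 400, while any `lr` set in the config still replaces the default.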
**Frequently-used command-line arguments**
| Command-Line Argument | Description | Default Value |
|:---:|:---:|:---:|
| `--data_root` | Root directory where datasets are stored. | `./data` |
| `--dataset` / `-d` | Dataset name defined in [qtcls/datasets/\_\_init\_\_.py](qtcls/datasets/__init__.py), such as `cifar10` and `imagenet1k`. | / |
| `--dummy` | Use fake data for `--dataset` / `-d` instead of loading real data (for fast debugging when no data is available or data loading is slow). | `False` |
| `--model_lib` | Library the model comes from: `default` (the toolbox's basic model library, extended from `torchvision` and `timm`) or `timm` (the original `timm`). | `default` |
| `--model` / `-m` | Model name defined in [qtcls/models/\_\_init\_\_.py](qtcls/models/__init__.py), such as `resnet50` and `vit_b_16`. Currently supported model names are listed in the Model Zoo. | / |
| `--criterion` | Criterion name defined in [qtcls/criterions/\_\_init\_\_.py](qtcls/criterions/__init__.py), such as `ce`. | `default` |
| `--optimizer` | Optimizer name defined in [qtcls/optimizers/\_\_init\_\_.py](qtcls/optimizers/__init__.py), such as `sgd` and `adam`. | `adamw` |
| `--scheduler` | Scheduler name defined in [qtcls/schedulers/\_\_init\_\_.py](qtcls/schedulers/__init__.py), such as `cosine`. | `cosine` |
| `--evaluator` | Evaluator name defined in [qtcls/evaluators/\_\_init\_\_.py](qtcls/evaluators/__init__.py). The `default` evaluator computes accuracy, recall, precision, and f1_score. | `default` |
| `--pretrain` / `-p` | Path to the pretrained weights. It takes priority over the path stored in [qtcls/models/\_pretrain\_.py](qtcls/models/_pretrain_.py). For a pretrained weight path you will reuse long-term, it is preferable to write it in [qtcls/models/\_pretrain\_.py](qtcls/models/_pretrain_.py). | / |
| `--no_pretrain` | Do not load pretrained weights. | `False` |
| `--resume` / `-r` | Checkpoint path to resume from. | / |
| `--output_dir` / `-o` | Directory for checkpoints, logs, and other outputs. | `./runs/__tmp__` |
| `--save_interval` | Interval for saving checkpoints. | `1` |
| `--clear_output_dir` / `-co` | Clear the output directory (`--output_dir` / `-o`) first. | `False` |
| `--comp_dirs` | List of log/output directories to compare. Their logs are plotted together with the current log in the same figure. | `[]` |
| `--batch_size` / `-b` | Batch size. | `256` |
| `--epochs` | Number of training epochs. | `300` |
| `--lr` | Learning rate. | `1e-4` |
| `--amp` | Enable automatic mixed precision training (faster, less GPU memory). | `False` |
| `--eval` | Evaluate only. | `False` |
| `--note` | A note printed after each epoch, as a reminder of what the run is for. | / |
| `--print_freq` | Print to the terminal every N iterations. | `50` |
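The `default` evaluator is described as reporting accuracy, recall, precision, and f1_score. The README does not say how per-class scores are averaged, so the pure-Python sketch below assumes macro averaging (an unweighted mean over classes); `classification_metrics` is a hypothetical name, not the evaluator's API.

```python
from collections import Counter
from typing import Dict, Sequence


def classification_metrics(y_true: Sequence[int], y_pred: Sequence[int],
                           num_classes: int) -> Dict[str, float]:
    tp, fp, fn = Counter(), Counter(), Counter()
    for t, p in zip(y_true, y_pred):
        if t == p:
            tp[t] += 1
        else:
            fp[p] += 1  # class p was predicted, but wrongly
            fn[t] += 1  # true class t was missed
    precisions, recalls, f1s = [], [], []
    for c in range(num_classes):
        prec = tp[c] / (tp[c] + fp[c]) if tp[c] + fp[c] else 0.0
        rec = tp[c] / (tp[c] + fn[c]) if tp[c] + fn[c] else 0.0
        f1 = 2 * prec * rec / (prec + rec) if prec + rec else 0.0
        precisions.append(prec)
        recalls.append(rec)
        f1s.append(f1)
    # macro averaging: every class counts equally, regardless of its size
    return {
        "accuracy": sum(tp.values()) / len(y_true),
        "precision": sum(precisions) / num_classes,
        "recall": sum(recalls) / num_classes,
        "f1_score": sum(f1s) / num_classes,
    }


print(classification_metrics([0, 0, 1, 1], [0, 1, 1, 1], num_classes=2))
# accuracy 0.75, precision ≈ 0.833, recall 0.75, f1_score ≈ 0.733
```

Macro averaging is chosen here because it does not let large classes dominate; a weighted or micro average would give different numbers on imbalanced data.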
### How to prepare data
Currently, `mnist`, `fashion_mnist`, `cifar10`, `cifar100`, `stl10`, `svhn`, `pets`, `flowers`, `cars` and `food`
datasets will be automatically downloaded to the `--data_root` directory. For other datasets, please refer to
["How to put your datasets"](data/README.md).
### How to customize
The toolbox is designed to be extended. See the following guides:
- [How to write and import your configs](configs/README.md)
- [How to put your datasets](data/README.md)
- [How to register your datasets](qtcls/datasets/README.md)
- [How to register your models](qtcls/models/README.md)
- [How to register your criterions](qtcls/criterions/README.md)
- [How to register your optimizers](qtcls/optimizers/README.md)
- [How to register your schedulers](qtcls/schedulers/README.md)
- [How to register your evaluators](qtcls/evaluators/README.md)
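The guides above document the toolbox's own registration mechanism. As a generic illustration of the name-to-constructor pattern behind arguments such as `--model` and `--criterion`, here is a minimal registry sketch; `MODEL_REGISTRY`, `register_model`, `build_model`, and `TinyNet` are hypothetical names, not qtcls APIs.

```python
from typing import Callable, Dict

# Maps a command-line name (e.g. the value of --model) to a constructor.
MODEL_REGISTRY: Dict[str, Callable[..., object]] = {}


def register_model(name: str):
    """Decorator that stores a model constructor under a command-line name."""
    def decorator(fn: Callable[..., object]) -> Callable[..., object]:
        if name in MODEL_REGISTRY:
            raise KeyError(f"model '{name}' is already registered")
        MODEL_REGISTRY[name] = fn
        return fn
    return decorator


def build_model(name: str, **kwargs) -> object:
    """Look up the constructor for `--model <name>` and call it."""
    if name not in MODEL_REGISTRY:
        raise ValueError(f"unknown model '{name}'; available: {sorted(MODEL_REGISTRY)}")
    return MODEL_REGISTRY[name](**kwargs)


class TinyNet:
    """Stand-in for a real network class."""
    def __init__(self, num_classes: int = 10):
        self.num_classes = num_classes


@register_model("tiny_net")
def tiny_net(**kwargs) -> TinyNet:
    return TinyNet(**kwargs)
```

With this pattern, `build_model("tiny_net", num_classes=100)` returns a `TinyNet` configured for 100 classes, and an unknown name fails with a message listing the registered ones.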
## Dataset Zoo
Currently supported values of `--dataset` / `-d`:
`mnist`, `fashion_mnist`, `cifar10`, `cifar100`, `stl10`, `svhn`, `pets`, `flowers`, `cars`, `food`,
`imagenet1k`, `imagenet21k` (also called imagenet22k),
and any dataset in `folder` format (the same storage layout as `imagenet`;
see ["How to put your datasets - About folder format datasets"](data/README.md) for details).
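In the `folder` format, each class has its own sub-directory of images, as in ImageNet. The sketch below assumes a `<split>/<class_name>/<image>` layout (see data/README.md for the exact layout the toolbox expects) and builds the label mapping by sorting the class-folder names, which is the convention torchvision's `ImageFolder` uses; `scan_folder_split` and the extension set are illustrative.

```python
from pathlib import Path
from typing import Dict, List, Tuple

IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}  # illustrative subset


def scan_folder_split(split_dir: str) -> Tuple[List[str], Dict[str, int], List[Tuple[str, int]]]:
    """Return (classes, class_to_idx, samples) for one split such as 'train'."""
    root = Path(split_dir)
    # sorting gives a stable, reproducible class -> index mapping
    classes = sorted(p.name for p in root.iterdir() if p.is_dir())
    class_to_idx = {name: idx for idx, name in enumerate(classes)}
    samples = []
    for name in classes:
        for file in sorted((root / name).rglob("*")):
            # keep image files only; extensions are compared case-insensitively
            if file.is_file() and file.suffix.lower() in IMAGE_EXTENSIONS:
                samples.append((str(file), class_to_idx[name]))
    return classes, class_to_idx, samples
```

Sorting matters because the integer labels must be identical across the training and validation splits and across runs.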
## Model Zoo
The toolbox's basic (default) model library is extended from `torchvision` and `timm`,
and the toolbox also supports the original `timm`.
### default
Set the argument `--model_lib` to `default`.
Currently supported values of `--model` / `-m`:
**AlexNet**
`alexnet`
**CaiT**
`cait_xxs24_224`, `cait_xxs24_384`, `cait_xxs36_224`, `cait_xxs36_384`, `cait_xs24_384`, `cait_s24_224`, `cait_s24_384`,
`cait_s36_384`, `cait_m36_384`, `cait_m48_448`
**ConvNeXt**
`convnext_tiny`, `convnext_small`, `convnext_base`, `convnext_large`
**DeiT**
`deit_tiny_patch16_224`, `deit_small_patch16_224`, `deit_base_patch16_224`, `deit_base_patch16_384`,
`deit_tiny_distilled_patch16_224`, `deit_small_distilled_patch16_224`, `deit_base_distilled_patch16_224`,
`deit_base_distilled_patch16_384`, `deit3_small_patch16_224`, `deit3_small_patch16_384`, `deit3_medium_patch16_224`,
`deit3_base_patch16_224`, `deit3_base_patch16_384`, `deit3_large_patch16_224`, `deit3_large_patch16_384`,
`deit3_huge_patch14_224`, `deit3_small_patch16_224_in21ft1k`, `deit3_small_patch16_384_in21ft1k`,
`deit3_medium_patch16_224_in21ft1k`, `deit3_base_patch16_224_in21ft1k`, `deit3_base_patch16_384_in21ft1k`,
`deit3_large_patch16_224_in21ft1k`, `deit3_large_patch16_384_in21ft1k`, `deit3_huge_patch14_224_in21ft1k`
**DenseNet**
`densenet121`, `densenet169`, `densenet201`, `densenet161`
**EfficientNet**
`efficientnet_b0`, `efficientnet_b1`, `efficientnet_b2`, `efficientnet_b3`, `efficientnet_b4`, `efficientnet_b5`,
`efficientnet_b6`, `efficientnet_b7`
**GoogLeNet**
`googlenet`
**Inception**
`inception_v3`
**LeViT**
`levit_128s`, `levit_128`, `levit_192`, `levit_256`, `levit_256d`, `levit_384`
**MLP-Mixer**
`mixer_s32_224`, `mixer_s16_224`, `mixer_b32_224`, `mixer_b16_224`, `mixer_b16_224_in21k`, `mixer_l32_224`,
`mixer_l16_224`, `mixer_l16_224_in21k`, `mixer_b16_224_miil_in21k`, `mixer_b16_224_miil`, `gmixer_12_224`,
`gmixer_24_224`, `resmlp_12_224`, `resmlp_24_224`, `resmlp_36_224`, `resmlp_big_24_224`, `resmlp_12_distilled_224`,
`resmlp_24_distilled_224`, `resmlp_36_distilled_224`, `resmlp_big_24_distilled_224`, `resmlp_big_24_224_in22ft1k`,
`resmlp_12_224_dino`, `resmlp_24_224_dino`, `gmlp_ti16_224`, `gmlp_s16_224`, `gmlp_b16_224`
**MNASNet**
`mnasnet0_5`, `mnasnet0_75`, `mnasnet1_0`, `mnasnet1_3`
**MobileNet**
`mobilenet_v2`, `mobilenetv3`, `mobilenet_v3_large`, `mobilenet_v3_small`
**PoolFormer**
`poolformer_s12`, `poolformer_s24`, `poolformer_s36`, `poolformer_m36`, `poolformer_m48`
**PVT**
`pvt_tiny`, `pvt_small`, `pvt_medium`, `pvt_large`, `pvt_huge_v2`
**RegNet**
`regnet_y_400mf`, `regnet_y_800mf`, `regnet_y_1_6gf`, `regnet_y_3_2gf`, `regnet_y_8gf`, `regnet_y_16gf`,
`regnet_y_32gf`, `regnet_y_128gf`, `regnet_x_400mf`, `regnet_x_800mf`, `regnet_x_1_6gf`, `regnet_x_3_2gf`,
`regnet_x_8gf`, `regnet_x_16gf`, `regnet_x_32gf`
**ResNet**
`resnet18`, `resnet34`, `resnet50`, `resnet101`, `resnet152`, `resnext50_32x4d`, `resnext101_32x8d`, `wide_resnet50_2`,
`wide_resnet101_2`
**ShuffleNetV2**
`shufflenet_v2_x0_5`, `shufflenet_v2_x1_0`, `shufflenet_v2_x1_5`, `shufflenet_v2_x2_0`
**SqueezeNet**
`squeezenet1_0`, `squeezenet1_1`
**Swin Transformer**
`swin_tiny_patch4_window7_224`, `swin_small_patch4_window7_224`, `swin_base_patch4_window7_224`,
`swin_base_patch4_window12_384`, `swin_large_patch4_window7_224`, `swin_large_patch4_window12_384`,
`swin_base_patch4_window7_224_in22k`, `swin_base_patch4_window12_384_in22k`, `swin_large_patch4_window7_224_in22k`,
`swin_large_patch4_window12_384_in22k`
**Swin Transformer V2**
`swinv2_tiny_window8_256`, `swinv2_tiny_window16_256`, `swinv2_small_window8_256`, `swinv2_small_window16_256`,
`swinv2_base_window8_256`, `swinv2_base_window16_256`, `swinv2_base_window12_192_22k`,
`swinv2_base_window12to16_192to256_22kft1k`, `swinv2_base_window12to24_192to384_22kft1k`,
`swinv2_large_window12_192_22k`, `swinv2_large_window12to16_192to256_22kft1k`,
`swinv2_large_window12to24_192to384_22kft1k`
**TNT**
`tnt_s_patch16_224`, `tnt_b_patch16_224`
**Twins**
`twins_pcpvt_small`, `twins_pcpvt_base`, `twins_pcpvt_large`, `twins_svt_small`, `twins_svt_base`, `twins_svt_large`
**VGG**
`vgg11`, `vgg11_bn`, `vgg13`, `vgg13_bn`, `vgg16`, `vgg16_bn`, `vgg19`, `vgg19_bn`
**Vision Transformer (timm)**
`vit_tiny_patch4_32`, `vit_tiny_patch16_224`, `vit_tiny_patch16_384`, `vit_small_patch32_224`, `vit_small_patch32_384`,
`vit_small_patch16_224`, `vit_small_patch16_384`, `vit_small_patch8_224`, `vit_base_patch32_224`,
`vit_base_patch32_384`, `vit_base_patch16_224`, `vit_base_patch16_384`, `vit_base_patch8_224`, `vit_large_patch32_224`,
`vit_large_patch32_384`, `vit_large_patch16_224`, `vit_large_patch16_384`, `vit_large_patch14_224`,
`vit_huge_patch14_224`, `vit_giant_patch14_224`
**Vision Transformer (torchvision)**
`vit_b_16`, `vit_b_32`, `vit_l_16`, `vit_l_32`
### timm
Set the argument `--model_lib` to `timm`.
Currently supported values of `--model` / `-m`:
all `timm` models are supported. Please refer to `timm` for the specific model names.
## License
QTClassification is released under the Apache 2.0 license. Please see the [LICENSE](LICENSE) file for more information.
Copyright (c) QIU Tian. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use these files except in compliance with
the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0.
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific
language governing permissions and limitations under the License.
## Citation
If you find QTClassification Toolbox useful in your research, please consider citing:
```bibtex
@misc{qtcls,
title={QTClassification},
author={Qiu, Tian},
howpublished={\url{https://github.com/horrible-dong/QTClassification}},
year={2023}
}
```