https://github.com/megvii-research/mdistiller

The official implementation of [CVPR2022] Decoupled Knowledge Distillation https://arxiv.org/abs/2203.08679 and [ICCV2023] DOT: A Distillation-Oriented Trainer https://openaccess.thecvf.com/content/ICCV2023/papers/Zhao_DOT_A_Distillation-Oriented_Trainer_ICCV_2023_paper.pdf

cifar coco computer-vision cvpr2022 deep-learning iccv2023 imagenet knowledge-distillation pytorch

This repo is

(1) a PyTorch library that provides classical knowledge distillation algorithms on mainstream CV benchmarks,

(2) the official implementation of the CVPR-2022 paper [Decoupled Knowledge Distillation](https://arxiv.org/abs/2203.08679), and

(3) the official implementation of the ICCV-2023 paper [DOT: A Distillation-Oriented Trainer](https://openaccess.thecvf.com/content/ICCV2023/papers/Zhao_DOT_A_Distillation-Oriented_Trainer_ICCV_2023_paper.pdf).

# DOT: A Distillation-Oriented Trainer

### Framework
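According to the DOT paper, the trainer treats the gradients of the task loss and of the distillation loss separately and gives the distillation-loss gradients a larger momentum, to speed up optimization of the distillation loss. The following is a minimal sketch of that idea on plain Python floats, assuming SGD with base momentum `mu` and a symmetric offset `delta` (task momentum `mu - delta`, distillation momentum `mu + delta`). `dot_step`, `momentum_step` and their arguments are hypothetical names, not the repo's trainer API, and the numbers are illustrative only.

```python
def dot_step(params, grads_task, grads_kd, buf_task, buf_kd, lr, mu, delta):
    """One DOT-style update (sketch). All lists are updated in place.

    Assumed rule: the task-loss gradient uses momentum (mu - delta) and the
    distillation-loss gradient uses momentum (mu + delta).
    """
    for i in range(len(params)):
        buf_task[i] = (mu - delta) * buf_task[i] + grads_task[i]
        buf_kd[i] = (mu + delta) * buf_kd[i] + grads_kd[i]
        params[i] -= lr * (buf_task[i] + buf_kd[i])


def momentum_step(params, grads, buf, lr, mu):
    """Plain SGD with momentum on a single (summed) gradient, for comparison."""
    for i in range(len(params)):
        buf[i] = mu * buf[i] + grads[i]
        params[i] -= lr * buf[i]


# Illustrative values only: one parameter, constant gradients.
params, buf_task, buf_kd = [1.0], [0.0], [0.0]
for _ in range(3):
    dot_step(params, [0.5], [1.0], buf_task, buf_kd, lr=0.1, mu=0.9, delta=0.05)
```

Setting `delta=0` gives both buffers the same momentum, so their sum follows the ordinary momentum recursion on `grads_task + grads_kd`; `momentum_step` is included only to make that equivalence easy to check.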

### Main Benchmark Results

On CIFAR-100 (top-1 accuracy, %):

| Teacher <br> Student | ResNet32x4 <br> ResNet8x4 | VGG13 <br> VGG8 | ResNet32x4 <br> ShuffleNet-V2 |
|:---------------:|:-----------------:|:-----------------:|:-----------------:|
| KD | 73.33 | 72.98 | 74.45 |
| **KD+DOT** | **75.12** | **73.77** | **75.55** |

On Tiny-ImageNet (top-1 accuracy, %):

| Teacher <br> Student | ResNet18 <br> MobileNet-V2 | ResNet18 <br> ShuffleNet-V2 |
|:---------------:|:-----------------:|:-----------------:|
| KD | 58.35 | 62.26 |
| **KD+DOT** | **64.01** | **65.75** |

On ImageNet (top-1 accuracy, %):

| Teacher <br> Student | ResNet34 <br> ResNet18 | ResNet50 <br> MobileNet-V1 |
|:---------------:|:-----------------:|:-----------------:|
| KD | 71.03 | 70.50 |
| **KD+DOT** | **71.72** | **73.09** |

# Decoupled Knowledge Distillation

### Framework & Performance
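The paper rewrites the classical KD loss as KD = TCKD + (1 - p_t^T) · NCKD. Here TCKD (target class knowledge distillation) is the KL divergence between the teacher's and the student's binary probabilities [p_t, 1 - p_t] for "target class vs. all other classes", NCKD (non-target class knowledge distillation) is the KL divergence between their distributions over the non-target classes only, and p_t^T is the teacher's target-class probability. DKD then weights the two terms independently as α · TCKD + β · NCKD. Below is a minimal single-sample sketch of that decomposition in plain Python (no PyTorch); the function names are hypothetical, and the T² scaling follows the common KD convention rather than anything stated in this README.

```python
import math


def softmax(logits, T=1.0):
    """Temperature-scaled softmax over a list of floats (numerically stable)."""
    m = max(logits)
    exps = [math.exp((z - m) / T) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl(p, q):
    """KL(p || q) for two discrete distributions given as lists."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def dkd_loss(logits_s, logits_t, target, alpha, beta, T):
    """DKD loss for one sample (sketch); returns (loss, tckd, nckd)."""
    p_s, p_t = softmax(logits_s, T), softmax(logits_t, T)
    # TCKD: binary distribution [target, all non-target classes]
    b_s = [p_s[target], 1.0 - p_s[target]]
    b_t = [p_t[target], 1.0 - p_t[target]]
    tckd = kl(b_t, b_s)
    # NCKD: distribution among non-target classes only, which equals the
    # softmax over the logits with the target entry removed
    rest_s = [z for i, z in enumerate(logits_s) if i != target]
    rest_t = [z for i, z in enumerate(logits_t) if i != target]
    nckd = kl(softmax(rest_t, T), softmax(rest_s, T))
    # T**2 scaling: common KD convention, assumed here
    loss = (alpha * tckd + beta * nckd) * T ** 2
    return loss, tckd, nckd
```

A quick sanity check: ignoring the T² factor, `kl(softmax(t, T), softmax(s, T))` should equal `tckd + (1 - softmax(t, T)[target]) * nckd` up to rounding, and identical student and teacher logits give a loss of zero.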

### Main Benchmark Results

On CIFAR-100 (top-1 accuracy, %):

| Teacher <br> Student | ResNet56 <br> ResNet20 | ResNet110 <br> ResNet32 | ResNet32x4 <br> ResNet8x4 | WRN-40-2 <br> WRN-16-2 | WRN-40-2 <br> WRN-40-1 | VGG13 <br> VGG8 |
|:---------------:|:-----------------:|:-----------------:|:-----------------:|:------------------:|:------------------:|:--------------------:|
| KD | 70.66 | 73.08 | 73.33 | 74.92 | 73.54 | 72.98 |
| **DKD** | **71.97** | **74.11** | **76.32** | **76.23** | **74.81** | **74.68** |

| Teacher <br> Student | ResNet32x4 <br> ShuffleNet-V1 | WRN-40-2 <br> ShuffleNet-V1 | VGG13 <br> MobileNet-V2 | ResNet50 <br> MobileNet-V2 | ResNet32x4 <br> ShuffleNet-V2 |
|:---------------:|:-----------------:|:-----------------:|:-----------------:|:------------------:|:------------------:|
| KD | 74.07 | 74.83 | 67.37 | 67.35 | 74.45 |
| **DKD** | **76.45** | **76.70** | **69.71** | **70.35** | **77.07** |

On ImageNet (top-1 accuracy, %):

| Teacher <br> Student | ResNet34 <br> ResNet18 | ResNet50 <br> MobileNet-V1 |
|:---------------:|:-----------------:|:-----------------:|
| KD | 71.03 | 70.50 |
| **DKD** | **71.70** | **72.05** |

# MDistiller

### Introduction

MDistiller supports the following distillation methods on CIFAR-100, ImageNet and MS-COCO:

|Method|Paper Link|CIFAR-100|ImageNet|MS-COCO|
|:---:|:---:|:---:|:---:|:---:|
|KD| |✓|✓| |
|FitNet| |✓| | |
|AT| |✓|✓| |
|NST| |✓| | |
|PKT| |✓| | |
|KDSVD| |✓| | |
|OFD| |✓|✓| |
|RKD| |✓| | |
|VID| |✓| | |
|SP| |✓| | |
|CRD| |✓|✓| |
|ReviewKD| |✓|✓|✓|
|DKD| |✓|✓|✓|

### Installation

Environment:

- Python 3.6
- PyTorch 1.9.0
- torchvision 0.10.0

Install the package:

```bash
sudo pip3 install -r requirements.txt
sudo python3 setup.py develop
```

### Getting started

0. Wandb as the logger

- Registration: .
- If you don't want to use wandb as your logger, set `CFG.LOG.WANDB` to `False` in `mdistiller/engine/cfg.py`.

1. Evaluation

- You can evaluate our released models or models you have trained yourself.

- Our models are available at ; please download the checkpoints to `./download_ckpts`.

- To test the models on ImageNet, download the dataset at and put it in `./data/imagenet`.

```bash
# evaluate teachers
python3 tools/eval.py -m resnet32x4 # resnet32x4 on cifar100
python3 tools/eval.py -m ResNet34 -d imagenet # ResNet34 on imagenet

# evaluate students
python3 tools/eval.py -m resnet8x4 -c download_ckpts/dkd_resnet8x4 # dkd-resnet8x4 on cifar100
python3 tools/eval.py -m MobileNetV1 -c download_ckpts/imgnet_dkd_mv1 -d imagenet # dkd-mv1 on imagenet
python3 tools/eval.py -m model_name -c output/your_exp/student_best # your checkpoints
```
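The benchmark numbers in this README are top-1 accuracies (%). As a repo-independent sketch (not taken from `tools/eval.py`), top-k accuracy over a batch of logits can be computed like this; `topk_accuracy` is a hypothetical helper.

```python
def topk_accuracy(logits, targets, k=1):
    """Percentage of samples whose target class is among the k highest logits."""
    correct = 0
    for row, target in zip(logits, targets):
        # indices of the k largest scores in this row
        top = sorted(range(len(row)), key=row.__getitem__, reverse=True)[:k]
        correct += target in top
    return 100.0 * correct / len(targets)


# Illustrative batch: 3 samples, 3 classes.
batch_logits = [[0.1, 0.9, 0.0], [0.6, 0.1, 0.3], [0.2, 0.3, 0.5]]
batch_targets = [1, 2, 2]
top1 = topk_accuracy(batch_logits, batch_targets, k=1)
top2 = topk_accuracy(batch_logits, batch_targets, k=2)
```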

2. Training on CIFAR-100

- Download `cifar_teachers.tar` at and untar it into `./download_ckpts` with `tar xvf cifar_teachers.tar`.

```bash
# for instance, our DKD method.
python3 tools/train.py --cfg configs/cifar100/dkd/res32x4_res8x4.yaml

# you can also change settings at command line
python3 tools/train.py --cfg configs/cifar100/dkd/res32x4_res8x4.yaml SOLVER.BATCH_SIZE 128 SOLVER.LR 0.1
```
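Command-line settings such as `SOLVER.BATCH_SIZE 128 SOLVER.LR 0.1` are `KEY VALUE` pairs merged over the YAML config. The repo's config is likely a yacs `CfgNode`, whose `merge_from_list` does this job; the sketch below is a hypothetical dict-based re-implementation of the idea, not the library's code, and the default values are illustrative placeholders.

```python
import ast


def merge_from_list(cfg, opts):
    """Merge alternating KEY VALUE strings into a nested dict config, in place."""
    if len(opts) % 2 != 0:
        raise ValueError("overrides must come in KEY VALUE pairs")
    for key, raw in zip(opts[0::2], opts[1::2]):
        *sections, leaf = key.split(".")
        node = cfg
        for name in sections:
            node = node[name]  # unknown sections raise KeyError
        if leaf not in node:
            raise KeyError(f"unknown config key: {key}")
        try:
            value = ast.literal_eval(raw)  # "128" -> 128, "0.1" -> 0.1, "False" -> False
        except (ValueError, SyntaxError):
            value = raw  # plain strings such as model names stay strings
        # strict type check (an assumption of this sketch)
        if type(value) is not type(node[leaf]):
            raise TypeError(
                f"{key}: expected {type(node[leaf]).__name__}, got {type(value).__name__}"
            )
        node[leaf] = value
    return cfg


# Illustrative defaults, not the repo's actual values.
cfg = {"SOLVER": {"BATCH_SIZE": 64, "LR": 0.05}, "LOG": {"WANDB": True}}
merge_from_list(cfg, ["SOLVER.BATCH_SIZE", "128", "SOLVER.LR", "0.1"])
```

By the same mechanism, appending `LOG.WANDB False` to the training command should switch off wandb without editing `cfg.py`, assuming `tools/train.py` merges command-line options as the example above suggests.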

3. Training on ImageNet

- Download the dataset at and put it in `./data/imagenet`.

```bash
# for instance, our DKD method.
python3 tools/train.py --cfg configs/imagenet/r34_r18/dkd.yaml
```

4. Training on MS-COCO

- See [detection/README.md](detection/README.md).

5. Extension: Visualizations

- Jupyter notebooks: [tsne](tools/visualizations/tsne.ipynb) and [correlation_matrices](tools/visualizations/correlation.ipynb)

### Custom Distillation Method

1. Create a Python file in `mdistiller/distillers/` and define the distiller:

```python
from ._base import Distiller


class MyDistiller(Distiller):
    def __init__(self, student, teacher, cfg):
        super(MyDistiller, self).__init__(student, teacher)
        self.hyper1 = cfg.MyDistiller.hyper1
        ...

    def forward_train(self, image, target, **kwargs):
        # return the output logits and a dict of losses
        ...

    # override get_learnable_parameters() if the distiller adds extra nn modules
    # override get_extra_parameters() if you want to report the extra cost
    ...
```
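The template leaves `forward_train` abstract. As a hypothetical, framework-free illustration of its contract (return the student logits plus a dict of named losses), here is classical KD in plain Python. `ToyKD`, `cross_entropy` and `kd_term` are illustrative names, the student and teacher are stand-in callables that return logit lists, and a real distiller would subclass `Distiller` and work on PyTorch tensors.

```python
import math


def cross_entropy(logits, target):
    """Cross-entropy of one sample computed from raw logits."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_z - logits[target]


def kd_term(logits_s, logits_t, T):
    """Classical KD: KL(teacher || student) on softened probabilities, times T**2."""
    def log_softmax(z):
        m = max(z)
        log_z = m / T + math.log(sum(math.exp((x - m) / T) for x in z))
        return [x / T - log_z for x in z]

    ls, lt = log_softmax(logits_s), log_softmax(logits_t)
    return T ** 2 * sum(math.exp(b) * (b - a) for a, b in zip(ls, lt))


class ToyKD:
    """Hypothetical stand-in for a Distiller subclass (not mdistiller code)."""

    def __init__(self, student, teacher, temperature=4.0, ce_weight=1.0, kd_weight=1.0):
        self.student, self.teacher = student, teacher
        self.temperature = temperature
        self.ce_weight, self.kd_weight = ce_weight, kd_weight

    def forward_train(self, image, target, **kwargs):
        logits_s = self.student(image)
        logits_t = self.teacher(image)  # in PyTorch this would run under torch.no_grad()
        losses = {
            "loss_ce": self.ce_weight * cross_entropy(logits_s, target),
            "loss_kd": self.kd_weight * kd_term(logits_s, logits_t, self.temperature),
        }
        return logits_s, losses
```

The training loop presumably combines the returned dict into one scalar (for example `sum(losses.values())`); keeping the terms separate lets each one be logged on its own.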

2. Register the distiller in `distiller_dict` in `mdistiller/distillers/__init__.py`.

3. Register the corresponding hyper-parameters in `mdistiller/engine/cfg.py`.

4. Create a new config file and test it.
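Steps 2 to 4 amount to a name-to-class registry plus default hyper-parameters. The sketch below uses plain dicts in place of the repo's config object; the `DISTILLER.TYPE` key, `build_distiller`, and the stand-in `Distiller` base class are assumptions for illustration, not the repo's actual code.

```python
class Distiller:
    """Hypothetical stand-in for mdistiller's `_base.Distiller`."""

    def __init__(self, student, teacher):
        self.student = student
        self.teacher = teacher


class MyDistiller(Distiller):
    def __init__(self, student, teacher, cfg):
        super().__init__(student, teacher)
        # dict access here; the repo's config uses attribute access (cfg.MyDistiller.hyper1)
        self.hyper1 = cfg["MyDistiller"]["hyper1"]


# Step 2: register the class under a name (mirrors `distiller_dict`).
distiller_dict = {"MYDISTILLER": MyDistiller}

# Step 3: default hyper-parameters (mirrors adding them in cfg.py); values are illustrative.
DEFAULT_CFG = {"DISTILLER": {"TYPE": "NONE"}, "MyDistiller": {"hyper1": 1.0}}


def build_distiller(cfg, student, teacher):
    """Instantiate the distiller named by the (assumed) DISTILLER.TYPE key."""
    name = cfg["DISTILLER"]["TYPE"]
    if name not in distiller_dict:
        raise KeyError(f"unregistered distiller: {name}")
    return distiller_dict[name](student, teacher, cfg)


# Step 4: a new config file would select the distiller by name.
cfg = {**DEFAULT_CFG, "DISTILLER": {"TYPE": "MYDISTILLER"}}
distiller = build_distiller(cfg, student="student-model", teacher="teacher-model")
```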

# Citation

If this repo is helpful for your research, please consider citing the papers:

```BibTeX
@article{zhao2022dkd,
  title={Decoupled Knowledge Distillation},
  author={Zhao, Borui and Cui, Quan and Song, Renjie and Qiu, Yiyu and Liang, Jiajun},
  journal={arXiv preprint arXiv:2203.08679},
  year={2022}
}

@article{zhao2023dot,
  title={DOT: A Distillation-Oriented Trainer},
  author={Zhao, Borui and Cui, Quan and Song, Renjie and Liang, Jiajun},
  journal={arXiv preprint arXiv:2307.08436},
  year={2023}
}
```

# License

MDistiller is released under the MIT license. See [LICENSE](LICENSE) for details.

# Acknowledgement

- Thanks to CRD and ReviewKD: this library is built on [CRD's codebase](https://github.com/HobbitLong/RepDistiller) and [ReviewKD's codebase](https://github.com/dvlab-research/ReviewKD).

- Thanks to Yiyu Qiu and Yi Shi for their code contributions during their internships at MEGVII Technology.

- Thanks to Xin Jin for discussions about DKD.