Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/Sharpiless/yolov5-knowledge-distillation
yolov5目标检测模型的知识蒸馏(基于响应的蒸馏)
https://github.com/Sharpiless/yolov5-knowledge-distillation
knowledge-distillation object-detection yolo yolov5
Last synced: 3 months ago
JSON representation
yolov5目标检测模型的知识蒸馏(基于响应的蒸馏)
- Host: GitHub
- URL: https://github.com/Sharpiless/yolov5-knowledge-distillation
- Owner: Sharpiless
- License: gpl-3.0
- Created: 2021-08-11T05:32:19.000Z (over 3 years ago)
- Default Branch: main
- Last Pushed: 2021-08-11T05:35:31.000Z (over 3 years ago)
- Last Synced: 2024-08-02T01:19:32.545Z (7 months ago)
- Topics: knowledge-distillation, object-detection, yolo, yolov5
- Language: Python
- Homepage:
- Size: 113 KB
- Stars: 89
- Watchers: 2
- Forks: 16
- Open Issues: 3
-
Metadata Files:
- Readme: README.md
- License: LICENSE
Awesome Lists containing this project
- awesome-yolo-object-detection - Sharpiless/yolov5-knowledge-distillation - knowledge-distillation?style=social"/> : yolov5目标检测模型的知识蒸馏(基于响应的蒸馏)。 (Lighter and Deployment Frameworks)
- awesome-yolo-object-detection - Sharpiless/yolov5-knowledge-distillation - knowledge-distillation?style=social"/> : yolov5目标检测模型的知识蒸馏(基于响应的蒸馏)。 (Lighter and Deployment Frameworks)
README
# 代码地址:
[https://github.com/Sharpiless/yolov5-knowledge-distillation](https://github.com/Sharpiless/yolov5-knowledge-distillation)
# 教师模型:
```bash
python train.py --weights weights/yolov5m.pt \
--cfg models/yolov5m.yaml --data data/voc.yaml --epochs 50 \
--batch-size 8 --device 0 --hyp data/hyp.scratch.yaml
```# 蒸馏训练:
```bash
python train.py --weights weights/yolov5s.pt \
--cfg models/yolov5s.yaml --data data/voc.yaml --epochs 50 \
--batch-size 8 --device 0 --hyp data/hyp.scratch.yaml \
--t_weights yolov5m.pt --distill
```# 训练参数:
> --weights:预训练模型
> --t_weights:教师模型权重
> --distill:使用知识蒸馏进行训练
> --dist_loss:l2或者kl
> --temperature:使用知识蒸馏时的温度
使用[《Object detection at 200 Frames Per Second》](https://arxiv.org/pdf/1805.06361.pdf)中的损失
这篇文章分别对这几个损失函数做出改进,具体思路为只有当teacher network的objectness value高时,才学习bounding box坐标和class probabilities。
data:image/s3,"s3://crabby-images/c9153/c915351edb18994893dd26c9faaa9f18c625b4d8" alt=""
data:image/s3,"s3://crabby-images/c821c/c821c7dd8ed33e74b4fe378364ceabbdcc34a404" alt=""
# 实验结果:
这里假设VOC2012中新增加的数据为无标签数据(2k张)。
|教师模型|训练方法|蒸馏损失|P|R|mAP50|
|:----|:----|:----|:----|:----|:----|
|无|正常训练|不使用|0.7756|0.7115|0.7609|
|Yolov5l|output based|l2|0.7585|0.7198|0.7644|
|Yolov5l|output based|KL|0.7417|0.7207|0.7536|
|Yolov5m|output based|l2|0.7682|0.7436|0.7976|
|Yolov5m|output based|KL|0.7731|0.7313|0.7931|data:image/s3,"s3://crabby-images/98bdb/98bdb13754c11d9bdfbf21eb5b65bddf2e5208d0" alt="训练结果"
参数和细节正在完善,支持KL散度、L2 logits损失和Sigmoid蒸馏损失等
## 1. 正常训练:
data:image/s3,"s3://crabby-images/daf92/daf92b2c024d687b2e36b9a6ee9b2473cc63d487" alt="正常训练"
## 2. L2蒸馏损失:
data:image/s3,"s3://crabby-images/481c2/481c243b041f41030b12ac4ef368c4e9c5c6f210" alt="L2蒸馏损失"
# 我的公众号:
data:image/s3,"s3://crabby-images/82271/82271452ea73adfe1f26aa865b03de05f286331f" alt="在这里插入图片描述"
# 关于作者
> B站:[https://space.bilibili.com/470550823](https://space.bilibili.com/470550823)> CSDN:[https://blog.csdn.net/weixin_44936889](https://blog.csdn.net/weixin_44936889)
> AI Studio:[https://aistudio.baidu.com/aistudio/personalcenter/thirdview/67156](https://aistudio.baidu.com/aistudio/personalcenter/thirdview/67156)
> Github:[https://github.com/Sharpiless](https://github.com/Sharpiless)