https://github.com/MasterSkepticista/detr
Implementation of DETR in JAX
https://github.com/MasterSkepticista/detr
detr flax jax
Last synced: 4 months ago
JSON representation
Implementation of DETR in JAX
- Host: GitHub
- URL: https://github.com/MasterSkepticista/detr
- Owner: MasterSkepticista
- Created: 2024-04-29T09:11:27.000Z (about 1 year ago)
- Default Branch: main
- Last Pushed: 2024-05-28T10:12:51.000Z (about 1 year ago)
- Last Synced: 2024-05-29T01:50:14.841Z (about 1 year ago)
- Topics: detr, flax, jax
- Language: Python
- Homepage:
- Size: 355 KB
- Stars: 0
- Watchers: 2
- Forks: 0
- Open Issues: 0
-
Metadata Files:
- Readme: README.md
Awesome Lists containing this project
- trackawesomelist - DETR (⭐2) - Flax implementation of [*DETR: End-to-end Object Detection with Transformers*](https://github.com/facebookresearch/detr) using Sinkhorn solver and parallel bipartite matching. (Recently Updated / [Feb 16, 2025](/content/2025/02/16/README.md))
README
# DETR: End-to-End Object Detection with Transformers.
This is a minimal implementation of [DETR](https://arxiv.org/abs/2005.12872) using `jax` and `flax`.
![]()
![]()
![]()
Features:
* Supports Flash Attention (up to 50% faster over and above the following optimizations).
* Supports Sinkhorn solver (faster training for roughly the same final AP).
* Parallel bipartite matching for all auxiliary outputs (up to 30% faster training using Hungarian matcher).
* Uses `optax` API.
* Bug fixes from scenic to match official DETR implementation.
* Supports BigTransfer (BiT-S) ResNet-50 backbone.You can read more about these optimizations [here](https://masterskepticista.github.io/portfolio/detr).
### Installation
* Setup:
```shell
git clone https://github.com/MasterSkepticista/detr.git && cd detr
python3.10 -m venv venv && source venv/bin/activate
pip install -U pip setuptools wheel
pip install -r requirements.txt
```* You may need to download MS-COCO dataset in TFDS. Run the following to download
and create TFRecords:
```shell
python -c "import tensorflow_datasets as tfds; tfds.load('coco/2017')"
```### Train
Download torch resnet50 checkpoint from GDrive.
```shell
pip install gdown# gdown -O
gdown 1q-PYc6ZshX12Nelb30V6Cp1FkmxhUdD2 -O artifacts/
```|Backbone|Top-1 Acc.|Checkpoint|
|--------|----------|----|
|BiT-R50x1-i1k|76.8%|[Link](https://drive.google.com/file/d/1iVBV9jghBR2mseSc5z2SB1b8QptI9mju/view?usp=drive_link)|
|R50x1-i1k (from torchvision)|76.1%|[Link](https://drive.google.com/file/d/1q-PYc6ZshX12Nelb30V6Cp1FkmxhUdD2/view?usp=sharing) (created using this [gist](https://gist.github.com/MasterSkepticista/c854bce837a5cb5ca0489bd33b3a2259))|```shell
# Trains the default DETR-R50-1333 model using `float32` precision.
# ~4.5 days on 8x A6000.
python main.py \
--config configs/hungarian.py --workdir artifacts/`date '+%m-%d_%H%M'`
```### Evaluate
Checkpoints (all non-DC5 variants) using the torchvision R50 backbone:|Checkpoint|GFLOPs|$AP$|$AP_{50}$|$AP_{75}$|$AP_S$|$AP_M$|$AP_L$|
|-|-|-|-|-|-|-|-|
[DETR-R50-1333*](https://drive.google.com/file/d/1fu4M3l88mhiQEUpADoUT2wrSEIZNDSqe/view?usp=sharing)|174.2|40.80|61.88|42.45|19.2|44.31|60.32|
[DETR-R50-640](https://drive.google.com/file/d/1XYV3ULIDwa59AVYSAvBeIOFXwRR_GZ46/view?usp=sharing)|38.5|33.14|52.89|34.00|10.54|35.10|55.53|\*official DETR baseline, except that these models were trained for 300 epochs instead of 500 epochs.
1. Download one of the pretrained checkpoints.
```python
# In configs/common.py (or any)
config.init_from = ml_collections.ConfigDict()
config.init_from.checkpoint_path = '/path/to/checkpoint'
```
2. Replace `config.total_epochs` with `config.total_steps = 0` to skip to eval.### Acknowledgements
Parts of this codebase are based on [scenic](https://github.com/google-research/scenic/).DETR implementation in PyTorch: [facebookresearch/detr](https://github.com/facebookresearch/detr).
### License
MIT