

[ICLR'23 Spotlight🔥] The first successful BERT/MAE-style pretraining on any convolutional network; Pytorch impl. of "Designing BERT for Convolutional Networks: Sparse and Hierarchical Masked Modeling"

bert cnn convnet convolutional-neural-networks deep-learning iclr iclr2023 instance-segmentation mae mask-rcnn masked-autoencoder masked-image-modeling object-detection pre-trained-model pretrain pretraining pytorch self-supervised-learning sparse-convolution ssl






# SparK: the first successful BERT/MAE-style pretraining on *any* convolutional network

This is the official implementation of the ICLR paper "Designing BERT for Convolutional Networks: ***Spar***se and Hierarchical Mas***k***ed Modeling", which can pretrain **any CNN** (e.g., ResNet) in a **BERT-style self-supervised** manner.
We've tried our best to keep the codebase clean, short, easy to read, and state-of-the-art, with minimal dependencies.


## 🔥 News

- A brief introduction (in English) is available on our ICLR poster page! [`📹Recorded Video, Poster, and Slides`]
- On **May 11th**, another livestream on OpenMMLab & ReadPaper (bilibili)! [`📹Recorded Video`]
- On **Apr. 27th (UTC+8 8pm)**, another livestream at OpenMMLab (bilibili)!
- On **Mar. 22nd (UTC+8 8pm)**, another livestream at 极市平台 (bilibili)! [`📹Recorded Video`]
- The share on TechBeat (将门创投) is scheduled on **Mar. 16th (UTC+8 8pm)** too! [`📹Recorded Video`]
- We are honored to be invited by Synced ("机器之心机动组 视频号" on WeChat) to give a talk about SparK on **Feb. 27th (UTC+0 11am, UTC+8 7pm)**. Welcome! [`📹Recorded Video`]
- This work was accepted to ICLR 2023 as a Spotlight (notable-top-25%).
- Other articles: [`Synced`]

## 🕹️ Colab Visualization Demo

Check [pretrain/viz_reconstruction.ipynb](pretrain/viz_reconstruction.ipynb) to visualize the reconstructions of SparK-pretrained models.

We also provide [pretrain/viz_spconv.ipynb](pretrain/viz_spconv.ipynb) that shows the "mask pattern vanishing" issue of dense conv layers.
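
To give a rough feel for the "mask pattern vanishing" issue, here is a toy sketch in plain Python. It is not the notebook's code and not SparK's actual sparse convolution. It assumes a binary 8×8 grid where 1 marks a visible position and 0 a masked one, applies a 3×3 all-ones dense convolution, and compares that with re-applying the original mask after every layer as a simple stand-in for sparse convolution. The helper names (`dense_conv3x3`, `masked_conv3x3`, `masked_ratio`) are hypothetical.

```python
def dense_conv3x3(grid):
    """3x3 all-ones convolution with zero padding; returns 1 where the output is nonzero."""
    h, w = len(grid), len(grid[0])
    out = [[0] * w for _ in range(h)]
    for i in range(h):
        for j in range(w):
            total = 0
            for di in (-1, 0, 1):
                for dj in (-1, 0, 1):
                    y, x = i + di, j + dj
                    if 0 <= y < h and 0 <= x < w:
                        total += grid[y][x]
            out[i][j] = 1 if total > 0 else 0
    return out


def masked_conv3x3(grid, mask):
    """Same convolution, but the output is zeroed wherever the original mask is 0."""
    dense = dense_conv3x3(grid)
    return [[dense[i][j] * mask[i][j] for j in range(len(mask[0]))]
            for i in range(len(mask))]


def masked_ratio(grid):
    """Fraction of positions that are still zero (i.e. still look masked)."""
    cells = [v for row in grid for v in row]
    return cells.count(0) / len(cells)


# Toy mask: only a 2x2 block in the top-left corner is visible (1 = visible, 0 = masked).
mask = [[1 if (i < 2 and j < 2) else 0 for j in range(8)] for i in range(8)]

dense, sparse = mask, mask
for layer in range(3):
    dense = dense_conv3x3(dense)
    sparse = masked_conv3x3(sparse, mask)
    print(f"layer {layer + 1}: dense masked ratio = {masked_ratio(dense):.2f}, "
          f"re-masked ratio = {masked_ratio(sparse):.2f}")
```

In this toy setting the dense path's masked ratio shrinks with every layer, while the re-masked path keeps the original pattern unchanged.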

## What's new here?

### 🔥 Pretrained CNN beats pretrained Swin-Transformer:

### 🔥 After SparK pretraining, smaller models can beat un-pretrained larger models:

### 🔥 All models can benefit, showing a scaling behavior:

### 🔥 Generative self-supervised pretraining surpasses contrastive learning:

#### See our paper for more analysis, discussions, and evaluations.

## Todo list


- [x] Pretraining code
- [x] Pretraining tutorial for customized CNN model (Tutorial for pretraining your own CNN model)
- [x] Pretraining tutorial for customized dataset (Tutorial for pretraining your own dataset)
- [x] Pretraining Colab visualization playground ([reconstruction](pretrain/viz_reconstruction.ipynb), [sparse conv](pretrain/viz_spconv.ipynb))
- [x] Finetuning code
- [ ] Weights & visualization playground in `huggingface`
- [ ] Weights in `timm`

## Pretrained weights (self-supervised; w/o decoder; can be directly finetuned)

**Note: for network definitions, we directly use `timm.models.ResNet` and the official ConvNeXt.**

`reso.`: the image resolution; `acc@1`: ImageNet-1K finetuned acc (top-1)

| arch. | reso. | acc@1 | #params | flops | weights (self-supervised, without SparK's decoder) |
|:---:|:---:|:---:|:---:|:---:|:---|
| ResNet50 | 224 | 80.6 | 26M | 4.1G | `resnet50_1kpretrained_timm_style.pth` |
| ResNet101 | 224 | 82.2 | 45M | 7.9G | `resnet101_1kpretrained_timm_style.pth` |
| ResNet152 | 224 | 82.7 | 60M | 11.6G | `resnet152_1kpretrained_timm_style.pth` |
| ResNet200 | 224 | 83.1 | 65M | 15.1G | `resnet200_1kpretrained_timm_style.pth` |
| ConvNeXt-S | 224 | 84.1 | 50M | 8.7G | `convnextS_1kpretrained_official_style.pth` |
| ConvNeXt-B | 224 | 84.8 | 89M | 15.4G | `convnextB_1kpretrained_official_style.pth` |
| ConvNeXt-L | 224 | 85.4 | 198M | 34.4G | `convnextL_1kpretrained_official_style.pth` |
| ConvNeXt-L | 384 | 86.0 | 198M | 101.0G | `convnextL_384_1kpretrained_official_style.pth` |

### Pretrained weights (with SparK's UNet-style decoder; can be used to reconstruct images)

| arch. | reso. | acc@1 | #params | flops | weights (self-supervised, with SparK's decoder) |
|:---:|:---:|:---:|:---:|:---:|:---|
| ResNet50 | 224 | 80.6 | 26M | 4.1G | `res50_withdecoder_1kpretrained_spark_style.pth` |
| ResNet101 | 224 | 82.2 | 45M | 7.9G | `res101_withdecoder_1kpretrained_spark_style.pth` |
| ResNet152 | 224 | 82.7 | 60M | 11.6G | `res152_withdecoder_1kpretrained_spark_style.pth` |
| ResNet200 | 224 | 83.1 | 65M | 15.1G | `res200_withdecoder_1kpretrained_spark_style.pth` |
| ConvNeXt-S | 224 | 84.1 | 50M | 8.7G | `cnxS224_withdecoder_1kpretrained_spark_style.pth` |
| ConvNeXt-L | 384 | 86.0 | 198M | 101.0G | `cnxL384_withdecoder_1kpretrained_spark_style.pth` |

## Installation & Running

We highly recommend using `torch==1.10.0`, `torchvision==0.11.1`, and `timm==0.5.4` for reproduction.
Check []( to install all pip dependencies.
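
If you want to check your environment against the recommended versions before running anything, a minimal sketch could look like the following. It only uses the standard library's `importlib.metadata`; the helper names (`installed_versions`, `find_mismatches`) are hypothetical and not part of this repo. Stripping local build tags such as `+cu113` before comparing is an assumption about how your wheels are versioned.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions recommended in this README for reproduction.
RECOMMENDED = {"torch": "1.10.0", "torchvision": "0.11.1", "timm": "0.5.4"}


def installed_versions(packages):
    """Return {package: installed version, or None if it is not installed}."""
    found = {}
    for name in packages:
        try:
            found[name] = version(name)
        except PackageNotFoundError:
            found[name] = None
    return found


def find_mismatches(installed, recommended=RECOMMENDED):
    """Return {package: (installed, recommended)} for every package that differs."""
    mismatches = {}
    for name, wanted in recommended.items():
        have = installed.get(name)
        # Drop a local build tag such as "+cu113" before comparing (assumption).
        base = have.split("+")[0] if have else None
        if base != wanted:
            mismatches[name] = (have, wanted)
    return mismatches


if __name__ == "__main__":
    report = find_mismatches(installed_versions(RECOMMENDED))
    for name, (have, wanted) in report.items():
        print(f"{name}: installed {have or 'nothing'}, recommended {wanted}")
```

Other versions may still work; this only flags differences from the combination recommended above.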

- **Loading pretrained model weights in 3 lines**

  ```python
  # download our weights `resnet50_1kpretrained_timm_style.pth` first
  import torch, timm
  res50, state = timm.create_model('resnet50'), torch.load('resnet50_1kpretrained_timm_style.pth', 'cpu')
  res50.load_state_dict(state.get('module', state), strict=False)  # just in case the model weights are actually saved in state['module']
  ```

- **Pretraining**
  - any ResNet or ConvNeXt on ImageNet-1k:  see [pretrain/](pretrain)
  - **your own CNN model**:  see [pretrain/](pretrain), especially [pretrain/models/](pretrain/models/

- **Finetuning**
  - any ResNet or ConvNeXt on ImageNet-1k:  check [downstream_imagenet/](downstream_imagenet) for subsequent instructions.
  - ResNets on COCO:  see [downstream_d2/](downstream_d2)
  - ConvNeXts on COCO:  see [downstream_mmdet/](downstream_mmdet)
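
A note on the weight-loading snippet above: with `strict=False`, `load_state_dict` does not raise on mismatched keys. Instead it returns an object with `missing_keys` and `unexpected_keys`, so it can be worth inspecting those before finetuning. The sketch below mimics that check on plain dictionaries, so it runs without `torch`. The helpers `unwrap_state` and `compare_keys` and the toy key names are made up for illustration. Which keys are actually missing for the released checkpoints is not stated here.

```python
def unwrap_state(state):
    """Return the weights dict, looking inside state['module'] if that key exists."""
    return state.get("module", state)


def compare_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) key lists, each sorted.

    missing:    expected by the model but absent from the checkpoint
    unexpected: present in the checkpoint but unknown to the model
    """
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)
    unexpected = sorted(checkpoint_keys - model_keys)
    return missing, unexpected


# Toy stand-ins for model.state_dict() and a loaded checkpoint (key names are made up).
model_state = {"conv1.weight": 0, "fc.weight": 0, "fc.bias": 0}
checkpoint = {"module": {"conv1.weight": 0, "decoder.weight": 0}}

missing, unexpected = compare_keys(model_state, unwrap_state(checkpoint))
print("missing:", missing)
print("unexpected:", unexpected)
```

With a real model, the same information is available from the return value of `res50.load_state_dict(..., strict=False)`.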

## Acknowledgement

We referred to these useful codebases:

- BEiT, MAE, ConvNeXt
- timm, MoCoV2, Detectron2, MMDetection

## License
This project is under the MIT license. See [LICENSE](LICENSE) for more details.

## Citation

If you find this project useful, you can kindly give us a star ⭐, or cite us in your work 📖:

```
@Article{tian2023designing,
  author  = {Keyu Tian and Yi Jiang and Qishuai Diao and Chen Lin and Liwei Wang and Zehuan Yuan},
  title   = {Designing BERT for Convolutional Networks: Sparse and Hierarchical Masked Modeling},
  journal = {arXiv:2301.03580},
  year    = {2023},
}
```