https://github.com/tonylianlong/crossmae

Official Implementation of the CrossMAE paper: Rethinking Patch Dependence for Masked Autoencoders

computer-vision deep-learning mae masked-autoencoder self-supervised-learning


## CrossMAE: Rethinking Patch Dependence for Masked Autoencoders
by Letian Fu*, Long Lian*, Renhao Wang, Baifeng Shi, Xudong Wang, Adam Yala†, Trevor Darrell†, Alexei A. Efros†, Ken Goldberg† at UC Berkeley and UCSF

[[Paper](https://openreview.net/forum?id=JT2KMuo2BV)] | [[Project Page](https://crossmae.github.io/)] | [[Citation](#citation)]



This is a PyTorch implementation of the CrossMAE paper [Rethinking Patch Dependence for Masked Autoencoders](https://crossmae.github.io/). The code is based on the original [MAE](https://github.com/facebookresearch/mae) repo. The codebase supports CrossMAE and MAE, with `timm==0.9.7`, `torch==2.0.0`, and flash-attn 2.

## Models
The CrossMAE encoder is identical to the MAE encoder, so we use the same code for fine-tuning. We also encourage you to try CrossMAE checkpoints in your downstream applications. These models are trained on ImageNet-1k for 800 epochs (the 448-resolution models are trained for 400 epochs) with masking ratio and kept mask ratio both set to 0.75, except ViT-H, which uses masking ratio 0.75 and kept mask ratio 0.25.
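To make the two ratios concrete, here is a minimal sketch of the per-image patch budget. It assumes 16x16 patches (as the `patch16` model names suggest) and, as an assumption this README does not confirm, that the kept mask ratio is a fraction of all patches that bounds how many masked patches the decoder reconstructs. `patch_budget` is a hypothetical helper, not part of the repository.

```python
def patch_budget(img_size=224, patch_size=16, mask_ratio=0.75, kept_mask_ratio=0.75):
    """Count patches seen by the encoder and reconstructed by the decoder.

    Assumption (not confirmed by the README): kept_mask_ratio is a fraction
    of ALL patches, capped by the number of masked patches.
    """
    num_patches = (img_size // patch_size) ** 2
    num_visible = int(num_patches * (1 - mask_ratio))  # patches fed to the encoder
    num_masked = num_patches - num_visible             # patches hidden from the encoder
    num_decoded = min(int(num_patches * kept_mask_ratio), num_masked)
    return {
        "total": num_patches,
        "visible": num_visible,
        "masked": num_masked,
        "decoded": num_decoded,
    }


if __name__ == "__main__":
    print(patch_budget(kept_mask_ratio=0.75))
    print(patch_budget(kept_mask_ratio=0.25))
    print(patch_budget(img_size=448, kept_mask_ratio=0.75))
```

Under these assumptions, a 224x224 image has 196 patches, of which 49 are visible. A kept mask ratio of 0.75 reconstructs all 147 masked patches, while 0.25 reconstructs only 49 of them.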

|  | ViT-Small | ViT-Base | ViT-Base448 | ViT-Large | ViT-Huge |
|---|---|---|---|---|---|
| pretrained checkpoint | download | download | download | download | download |
| fine-tuned checkpoint | download | download | download | download | download |
| Reference ImageNet accuracy (ours) | 79.318 | 83.722 | 84.598 | 85.432 | 86.256 |

MAE ImageNet accuracy (baseline)

84.8

85.9

## Train CrossMAE on **one single RTX 4090**
CrossMAE is efficient enough to train on **one single RTX 4090** in a personal computer. The reference machine used an i9-14900K CPU and 96 GB of RAM.

Instructions and trained models

The training and fine-tuning commands (with `${IMAGENET_DIR}` set to the ImageNet directory, using ViT-S as an example):
```sh
CUDA_VISIBLE_DEVICES=0 OMP_NUM_THREADS=1 torchrun --nproc_per_node=1 --master_port 2780 main_pretrain.py --batch_size 512 --accum_iter 8 --model mae_vit_small_patch16 --norm_pix_loss --blr 1.5e-4 --weight_decay 0.05 --data_path ${IMAGENET_DIR} --num_workers 16 --multi_epochs_dataloader --output_dir output/imagenet-crossmae-vits-pretrain-wfm-mr0.75-kmr0.25-dd12-ep800 --cross_mae --weight_fm --decoder_depth 12 --mask_ratio 0.75 --kept_mask_ratio 0.75 --epochs 800 --warmup_epochs 40 --use_input

CUDA_VISIBLE_DEVICES=0 OMP_NUM_THREADS=1 torchrun --nproc_per_node=1 --master_port 2860 main_finetune.py --batch_size 512 --accum_iter 2 --model vit_small_patch16 --finetune output/imagenet-crossmae-vits-pretrain-wfm-mr0.75-kmr0.25-dd12-ep800/checkpoint.pth --epoch 100 --blr 5e-4 --layer_decay 0.65 --weight_decay 0.05 --drop_path 0.1 --reprob 0.25 --mixup 0.8 --cutmix 1.0 --dist_eval --data_path ${IMAGENET_DIR} --num_workers 12 --output_dir output/imagenet-crossmae-vits-finetune-wfm-mr0.75-kmr0.25-dd12-ep800 --multi_epochs_dataloader
# Reference results:
# * Acc@1 79.462 Acc@5 94.864 loss 0.907
```




| pretrained checkpoint | fine-tuned checkpoint | reference ImageNet accuracy |
|---|---|---|
| download | download | 79.462 |

## Instructions
Please install the dependencies in `requirements.txt`:
```sh
# Optionally create a conda environment
conda create -n crossmae python=3.10 -y
conda activate crossmae
# Install dependencies
pip install -r requirements.txt
```
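After installation, a quick check like the following can confirm that the environment matches the versions this README names (`timm==0.9.7`, `torch==2.0.0`, flash-attn 2). This is only a sketch: `check_pins` is a hypothetical helper, and the distribution name `flash-attn` is assumed from the `pip install flash-attn` command used in this README.

```python
from importlib.metadata import PackageNotFoundError, version

# Pins taken from the README text; flash-attn is only pinned to major version 2.
EXPECTED = {"timm": "0.9.7", "torch": "2.0.0", "flash-attn": "2"}


def check_pins(expected=None):
    """Return {package: (installed_version_or_None, matches_pin)}."""
    expected = EXPECTED if expected is None else expected
    report = {}
    for name, pin in expected.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            installed = None
        ok = False
        if installed is not None:
            # Drop local build tags such as "+cu117" before comparing.
            public = installed.split("+")[0]
            ok = public == pin or public.startswith(pin + ".")
        report[name] = (installed, ok)
    return report


if __name__ == "__main__":
    for name, (installed, ok) in check_pins().items():
        status = "ok" if ok else "mismatch or missing"
        print(f"{name}: installed={installed} -> {status}")
```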

### Pre-training CrossMAE
To pre-train ViT-Base, run the following on 4 GPUs:
```sh
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 --master_port 1234 main_pretrain.py --batch_size 1024 --model mae_vit_base_patch16 --norm_pix_loss --blr 1.5e-4 --weight_decay 0.05 --data_path ${IMAGENET_DIR} --num_workers 20 --enable_flash_attention2 --multi_epochs_dataloader --output_dir output/imagenet-crossmae-vitb-pretrain-wfm-mr0.75-kmr0.25-dd12-ep800 --cross_mae --weight_fm --decoder_depth 12 --mask_ratio 0.75 --kept_mask_ratio 0.25 --epochs 800 --warmup_epochs 40 --use_input
```

To train ViT-Small or ViT-Large, set `--model mae_vit_small_patch16` or `--model mae_vit_large_patch16`. You can use `--accum_iter` to perform gradient accumulation if your hardware cannot fit the batch size. [FlashAttention 2](https://github.com/Dao-AILab/flash-attention) should be installed with `pip install flash-attn --no-build-isolation`.
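The interaction between `--batch_size`, `--accum_iter`, and the number of GPUs can be sketched as follows. This assumes CrossMAE keeps the convention of the MAE codebase it builds on, where `--batch_size` is per GPU and the absolute learning rate is `blr * effective_batch_size / 256`. The helper names are hypothetical.

```python
def effective_batch_size(batch_size, accum_iter=1, world_size=1):
    """Samples contributing to one optimizer step (per-GPU batch x accumulation x GPUs)."""
    return batch_size * accum_iter * world_size


def absolute_lr(blr, eff_batch_size, base=256):
    """Scale the base learning rate linearly with the effective batch size.

    Assumption: the MAE rule lr = blr * eff_batch_size / 256 is unchanged here.
    """
    return blr * eff_batch_size / base


if __name__ == "__main__":
    # 4-GPU pre-training command: --batch_size 1024, no accumulation.
    four_gpu = effective_batch_size(1024, accum_iter=1, world_size=4)
    # Single RTX 4090 pre-training command: --batch_size 512 --accum_iter 8.
    one_gpu = effective_batch_size(512, accum_iter=8, world_size=1)
    print(four_gpu, one_gpu)
    print(absolute_lr(1.5e-4, four_gpu))
```

Under that assumption, the 4-GPU command and the single RTX 4090 command both reach an effective batch size of 4096, so lowering `--batch_size` while raising `--accum_iter` by the same factor keeps the learning rate unchanged.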

### Fine-tuning CrossMAE
To fine-tune ViT-Base, run the following on 4 GPUs:
```sh
CUDA_VISIBLE_DEVICES=0,1,2,3 torchrun --nproc_per_node=4 --master_port 1234 main_finetune.py --batch_size 256 --model vit_base_patch16 --finetune output/imagenet-crossmae-vitb-pretrain-wfm-mr0.75-kmr0.25-dd12-ep800/checkpoint.pth --epoch 100 --blr 5e-4 --layer_decay 0.65 --weight_decay 0.05 --drop_path 0.1 --reprob 0.25 --mixup 0.8 --cutmix 1.0 --dist_eval --data_path ${IMAGENET_DIR} --output_dir output/imagenet-crossmae-vitb-finetune-wfm-mr0.75-kmr0.25-dd12-ep800 --enable_flash_attention2 --multi_epochs_dataloader
```
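The `--layer_decay 0.65` flag controls layer-wise learning-rate decay. The sketch below follows the scheme used by the MAE fine-tuning code, which this repository states it reuses; treating it as unchanged in CrossMAE is an assumption, and `layer_lr_scales` is a hypothetical name.

```python
def layer_lr_scales(num_blocks, layer_decay):
    """Return one learning-rate multiplier per layer group.

    Index 0 covers the embeddings, indices 1..num_blocks cover the transformer
    blocks, and the last index covers the final norm and classification head.
    """
    num_layers = num_blocks + 1
    return [layer_decay ** (num_layers - i) for i in range(num_layers + 1)]


if __name__ == "__main__":
    # ViT-Base has 12 transformer blocks.
    scales = layer_lr_scales(num_blocks=12, layer_decay=0.65)
    for i, s in enumerate(scales):
        print(f"layer group {i:2d}: lr multiplier {s:.6f}")
```

With this scheme the head keeps the full learning rate, and each earlier layer group is scaled down by a further factor of 0.65, so the pre-trained early layers change the least during fine-tuning.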

## Evaluation
Evaluate ViT-Base on a single GPU, where `${IMAGENET_DIR}` is a directory containing the `{train, val}` sets of ImageNet and `${FINETUNED_CHECKPOINT_PATH}` is the path to the fine-tuned checkpoint:
```sh
python main_finetune.py --eval --resume ${FINETUNED_CHECKPOINT_PATH} --model vit_base_patch16 --batch_size 16 --data_path ${IMAGENET_DIR}
```
This should give:
```
* Acc@1 83.722 Acc@5 96.686 loss 0.729
```

You can replace `vit_base_patch16` with `vit_small_patch16` or `vit_large_patch16` to evaluate ViT-S or ViT-L. To work with 448 input resolution, append `--input_size 448` to the command line.
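For readers unfamiliar with the `Acc@1` and `Acc@5` figures reported above, the following pure-Python sketch shows how top-k accuracy is defined. It is an illustration of the metric only, not the repository's evaluation code, and `topk_accuracy` is a hypothetical helper.

```python
def topk_accuracy(logits, targets, ks=(1, 5)):
    """Return {k: percentage of samples whose true label is among the k top-scoring classes}."""
    hits = {k: 0 for k in ks}
    for scores, target in zip(logits, targets):
        # Class indices sorted from highest to lowest score.
        ranked = sorted(range(len(scores)), key=lambda c: scores[c], reverse=True)
        for k in ks:
            if target in ranked[:k]:
                hits[k] += 1
    n = len(targets)
    return {k: 100.0 * hits[k] / n for k in ks}


if __name__ == "__main__":
    # Toy example with 3 classes and 3 samples.
    logits = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2], [0.2, 0.3, 0.5]]
    targets = [1, 1, 0]
    print(topk_accuracy(logits, targets, ks=(1, 2)))
```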

### License

This project is under the CC-BY-NC 4.0 license. See [LICENSE](LICENSE) for details.

## Citation
Please give us a star 🌟 on GitHub to support us!

Please cite our work if you find it inspiring or use our code in your work:
```
@article{
fu2025rethinking,
title={Rethinking Patch Dependence for Masked Autoencoders},
author={Letian Fu and Long Lian and Renhao Wang and Baifeng Shi and XuDong Wang and Adam Yala and Trevor Darrell and Alexei A Efros and Ken Goldberg},
journal={Transactions on Machine Learning Research},
issn={2835-8856},
year={2025},
url={https://openreview.net/forum?id=JT2KMuo2BV},
note={}
}
```