https://github.com/damo-nlp-sg/digit
[NeurIPS 2024] Stabilize the Latent Space for Image Autoregressive Modeling: A Unified Perspective
- Host: GitHub
- URL: https://github.com/damo-nlp-sg/digit
- Owner: DAMO-NLP-SG
- License: MIT
- Created: 2024-10-15T03:02:22.000Z (12 months ago)
- Default Branch: main
- Last Pushed: 2024-10-21T11:39:40.000Z (12 months ago)
- Last Synced: 2024-10-23T03:38:38.282Z (12 months ago)
- Topics: autoregressive, fairseq, gpt, image-generation, language-model, neurips, transformer
- Language: Python
- Homepage:
- Size: 14.5 MB
- Stars: 30
- Watchers: 5
- Forks: 2
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
README
# Stabilize the Latent Space for Image Autoregressive Modeling: A Unified Perspective (NeurIPS 2024)
[arXiv](https://arxiv.org/abs/2410.12490) · [Papers with Code](https://paperswithcode.com/sota/image-generation-on-imagenet-256x256?tag_filter=485&p=stabilize-the-latent-space-for-image)
## Overview

We present **DiGIT**, an auto-regressive generative model that performs next-token prediction in an abstract latent space derived from self-supervised learning (SSL) models. By applying K-Means clustering to the hidden states of the DINOv2 model, we effectively create a novel discrete tokenizer. This method significantly boosts image generation performance on the ImageNet dataset, achieving an FID score of **4.59 for class-unconditional tasks** and **3.39 for class-conditional tasks**. Additionally, the model enhances image understanding, achieving a **linear-probe accuracy of 80.3**.
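Conceptually, the tokenizer is a nearest-centroid lookup over DINOv2 patch features. The snippet below is only an illustrative sketch, not the repository's code: the torch.hub entry name, the centroid file name, the 224px input size, and the choice of layer are assumptions; the real pipeline is driven by `preprocess/run.sh`.

```python
import numpy as np
import torch

# Illustrative sketch of the discrete SSL tokenizer (hub entry, centroid file,
# input size, and feature layer are assumptions; see preprocess/run.sh for the real pipeline).
backbone = torch.hub.load("facebookresearch/dinov2", "dinov2_vitb14_reg")
backbone.eval()

centroids = torch.from_numpy(np.load("km_8k.npy")).float()  # (K, D) K-Means centroids

@torch.no_grad()
def tokenize(images: torch.Tensor) -> torch.Tensor:
    """Map images (B, 3, 224, 224) to discrete token ids (B, num_patches)."""
    feats = backbone.get_intermediate_layers(images, n=1)[0]          # (B, N, D) patch features
    flat = feats.reshape(-1, feats.shape[-1])                         # (B*N, D)
    dists = torch.cdist(flat, centroids.to(flat.device))              # (B*N, K) L2 distances
    return dists.argmin(dim=-1).view(feats.shape[0], feats.shape[1])  # nearest centroid per patch
```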
## Experimental Results
### Linear-Probe Accuracy on ImageNet
| Methods | \# Tokens | Features | \# Params | Top-1 Acc. $\uparrow$ |
|-----------------------------------|-------------|----------|------------|-----------------------|
| iGPT-L | 32 $\times$ 32 | 1536 | 1362M | 60.3 |
| iGPT-XL | 64 $\times$ 64 | 3072 | 6801M | 68.7 |
| VIM+VQGAN | 32 $\times$ 32 | 1024 | 650M | 61.8 |
| VIM+dVAE | 32 $\times$ 32 | 1024 | 650M | 63.8 |
| VIM+ViT-VQGAN | 32 $\times$ 32 | 1024 | 650M | 65.1 |
| VIM+ViT-VQGAN | 32 $\times$ 32 | 2048 | 1697M | 73.2 |
| AIM | 16 $\times$ 16 | 1536 | 0.6B | 70.5 |
| **DiGIT (Ours)** | 16 $\times$ 16 | 1024 | 219M | 71.7 |
| **DiGIT (Ours)**                  | 16 $\times$ 16 | 1536     | 732M       | **80.3**              |

### Class-Unconditional Image Generation on ImageNet (Resolution: 256 $\times$ 256)
| Type | Methods | \# Param | \# Epoch | FID $\downarrow$ | IS $\uparrow$ |
|-------|-------------------------------------|----------|----------|------------------|----------------|
| GAN | BigGAN | 70M | - | 38.6 | 24.70 |
| Diff. | LDM | 395M | - | 39.1 | 22.83 |
| Diff. | ADM | 554M | - | 26.2 | 39.70 |
| MIM | MAGE | 200M | 1600 | 11.1 | 81.17 |
| MIM | MAGE | 463M | 1600 | 9.10 | 105.1 |
| MIM | MaskGIT | 227M | 300 | 20.7 | 42.08 |
| MIM | **DiGIT (+MaskGIT)** | 219M | 200 | **9.04** | **75.04** |
| AR | VQGAN | 214M | 200 | 24.38 | 30.93 |
| AR | **DiGIT (+VQGAN)** | 219M | 400 | **9.13** | **73.85** |
| AR    | **DiGIT (+VQGAN)**                  | 732M     | 200      | **4.59**         | **141.29**     |

### Class-Conditional Image Generation on ImageNet (Resolution: 256 $\times$ 256)
| Type | Methods | \# Param | \# Epoch | FID $\downarrow$ | IS $\uparrow$ |
|-------|----------------------|----------|----------|------------------|----------------|
| GAN | BigGAN | 160M | - | 6.95 | 198.2 |
| Diff. | ADM | 554M | - | 10.94 | 101.0 |
| Diff. | LDM-4 | 400M | - | 10.56 | 103.5 |
| Diff. | DiT-XL/2 | 675M | - | 9.62 | 121.50 |
| Diff. | L-DiT-7B | 7B | - | 6.09 | 153.32 |
| MIM | CQR-Trans | 371M | 300 | 5.45 | 172.6 |
| MIM+AR | VAR | 310M | 200 | 4.64 | - |
| MIM+AR | VAR | 310M | 200 | 3.60* | 257.5* |
| MIM+AR | VAR | 600M | 250 | 2.95* | 306.1* |
| MIM | MAGVIT-v2 | 307M | 1080 | 3.65 | 200.5 |
| AR | VQVAE-2 | 13.5B | - | 31.11 | 45 |
| AR | RQ-Trans | 480M | - | 15.72 | 86.8 |
| AR | RQ-Trans | 3.8B | - | 7.55 | 134.0 |
| AR | ViTVQGAN | 650M | 360 | 11.20 | 97.2 |
| AR | ViTVQGAN | 1.7B | 360 | 5.3 | 149.9 |
| MIM | MaskGIT | 227M | 300 | 6.18 | 182.1 |
| MIM | **DiGIT (+MaskGIT)** | 219M | 200 | **4.62** | **146.19** |
| AR | VQGAN | 227M | 300 | 18.65 | 80.4 |
| AR | **DiGIT (+VQGAN)** | 219M | 400 | **4.79** | **142.87** |
| AR    | **DiGIT (+VQGAN)**   | 732M     | 200      | **3.39**         | **205.96**     |

\*: VAR is trained with classifier-free guidance while all the other models are not.
## Checkpoints
The K-Means centroids (.npy) and model checkpoints can be downloaded from:

| Model | Link |
|:----------:|:-----:|
| HF weights🤗 | [Huggingface](https://huggingface.co/DAMO-NLP-SG/DiGIT) |

For the base model we use [DINOv2-base](https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_pretrain.pth), and for the large model [DINOv2-large](https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_reg4_pretrain.pth). The VQGAN we use is the same as in [MAGE](https://drive.google.com/file/d/13S_unB87n6KKuuMdyMnyExW0G1kplTbP/view?usp=sharing).
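If you prefer fetching everything programmatically, a minimal sketch with `huggingface_hub` (the destination directory is illustrative; arrange the downloaded files into the layout shown below):

```python
from huggingface_hub import snapshot_download  # pip install huggingface_hub

# Fetch the released K-Means centroids and DiGIT checkpoints from the Hugging Face repo above.
# The local destination is an assumption; place the files according to the layout below.
local_path = snapshot_download(repo_id="DAMO-NLP-SG/DiGIT", local_dir="data")
print(local_path)
```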
```
DiGIT
└── data/
    ├── ILSVRC2012
    ├── dinov2_base_short_224_l3
    │   └── km_8k.npy
    ├── dinov2_large_short_224_l3
    │   └── km_16k.npy
    ├── outputs/
    │   ├── base_8k_stage1
    │   └── ...
    └── models/
        ├── vqgan_jax_strongaug.ckpt
        ├── dinov2_vitb14_reg4_pretrain.pth
        └── dinov2_vitl14_reg4_pretrain.pth
```

## Preparation
### Installation
1. Download the code
```shell
git clone https://github.com/DAMO-NLP-SG/DiGIT.git
cd DiGIT
```

2. Install `fairseq` via `pip install fairseq`.
### Dataset Preparation
Download the [ImageNet](http://image-net.org/) dataset and place it in your dataset directory `$PATH_TO_YOUR_WORKSPACE/dataset/ILSVRC2012`.

### Tokenizer
Extract SSL features and save them as .npy files. Use the K-Means algorithm with [faiss](https://github.com/facebookresearch/faiss) to compute the centroids. You can also use our pre-trained centroids available on [Huggingface](https://huggingface.co/DAMO-NLP-SG/DiGIT).

```shell
bash preprocess/run.sh
```
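For reference, fitting the centroids with faiss on a stack of saved SSL features looks roughly like the sketch below. The file names and the cluster count are illustrative assumptions; the actual pipeline is driven by `preprocess/run.sh`.

```python
import faiss
import numpy as np

# Minimal sketch: fit K-Means centroids on pre-extracted SSL features with faiss.
# File names and cluster count are illustrative; see preprocess/run.sh for the real pipeline.
feats = np.load("ssl_features.npy").astype("float32")   # (N, D) stacked DINOv2 patch features
kmeans = faiss.Kmeans(d=feats.shape[1], k=8192, niter=50, gpu=True, verbose=True)
kmeans.train(feats)
np.save("km_8k.npy", kmeans.centroids)                  # (k, D) centroids used as the token codebook

# Token assignment is then a nearest-centroid search:
_, token_ids = kmeans.index.search(feats, 1)            # (N, 1) id of the nearest centroid
```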
### Training Scripts

**Step 1**
Train a GPT model with the discriminative tokenizer. The training script is `scripts/train_stage1_ar.sh`, and the hyperparameters are in `config/stage1/dino_base.yaml`. For the class-conditional generation configuration, see `scripts/train_stage1_classcond.sh`.
**Step 2**
Train a pixel decoder (either an AR or a NAR model) conditioned on the discriminative tokens. The autoregressive training script is `scripts/train_stage2_ar.sh`, and the NAR training script is `scripts/train_stage2_nar.sh`.
A folder named `outputs/EXP_NAME/checkpoints` will be created to save the checkpoints. TensorBoard log files are saved at `outputs/EXP_NAME/tb`. Logs will be recorded in `outputs/EXP_NAME/train.log`.
You can monitor the training process using `tensorboard --logdir=outputs/EXP_NAME/tb`.
### Sampling Scripts
First, sample discriminative tokens with `scripts/infer_stage1_ar.sh`. For the base model we recommend setting topk=200; for the large model, use topk=400.
Then run `scripts/infer_stage2_ar.sh` to sample VQ tokens based on the previously sampled discriminative tokens.
Generated tokens and synthesized images will be stored in a directory named `outputs/EXP_NAME/results`.
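To clarify what the `topk` setting controls: at each autoregressive step, only the k highest-scoring tokens are kept before sampling. A generic sketch of top-k filtering (not the repository's fairseq sampling code):

```python
import torch

def sample_top_k(logits: torch.Tensor, k: int = 200, temperature: float = 1.0) -> torch.Tensor:
    """Sample one token id per row, restricting the choice to the k highest-scoring tokens."""
    logits = logits / temperature
    topk_vals, _ = torch.topk(logits, k, dim=-1)
    # Mask everything below the k-th largest logit, then sample from the renormalized distribution.
    logits = logits.masked_fill(logits < topk_vals[..., -1, None], float("-inf"))
    probs = torch.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)
```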
### FID and IS evaluation
Prepare the ImageNet validation set for FID evaluation:
```shell
python prepare_imgnet_val.py --data_path $PATH_TO_YOUR_WORKSPACE/dataset/ILSVRC2012 --output_dir imagenet-val
```

Install the evaluation tool by running `pip install torch-fidelity`.
Execute the following command to evaluate FID:
```shell
python fairseq_user/eval_fid.py --results-path $IMG_SAVE_DIR --subset $GEN_SUBSET
```
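The evaluation relies on `torch-fidelity`; for a quick standalone check you can also call the library's Python API directly. A minimal sketch, assuming the generated images live under `outputs/EXP_NAME/results` and the reference folder is the one produced by `prepare_imgnet_val.py`:

```python
import torch_fidelity

# Standalone sketch: compute FID/IS directly with the torch-fidelity Python API.
# Both paths are assumptions: input1 is a folder of generated images,
# input2 is the reference folder prepared above.
metrics = torch_fidelity.calculate_metrics(
    input1="outputs/EXP_NAME/results",
    input2="imagenet-val",
    cuda=True,
    fid=True,
    isc=True,
)
print(metrics)  # e.g. {'frechet_inception_distance': ..., 'inception_score_mean': ...}
```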
### Linear Probe training

```shell
bash scripts/train_stage1_linearprobe.sh
```
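Conceptually, the linear probe trains only a linear classifier on top of frozen features. The sketch below is a generic illustration (the feature dimension, pooling, and training loop are assumptions); the actual recipe is in `scripts/train_stage1_linearprobe.sh`.

```python
import torch
import torch.nn as nn

# Generic linear-probe sketch: a single linear layer trained on frozen features.
# Feature extraction and dimensions are placeholders; see scripts/train_stage1_linearprobe.sh.
feat_dim, num_classes = 1024, 1000
probe = nn.Linear(feat_dim, num_classes)
optimizer = torch.optim.AdamW(probe.parameters(), lr=1e-3)
criterion = nn.CrossEntropyLoss()

def probe_step(frozen_feats: torch.Tensor, labels: torch.Tensor) -> float:
    """One optimization step on pooled, frozen backbone features of shape (B, feat_dim)."""
    logits = probe(frozen_feats.detach())  # the backbone stays frozen; only the probe is trained
    loss = criterion(logits, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```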
## License

This project is licensed under the MIT License; see the [LICENSE](LICENSE) file for details.

## Citation
If you find our project useful, please star our repo and cite our work as follows.
```bibtex
@misc{zhu2024stabilize,
  title={Stabilize the Latent Space for Image Autoregressive Modeling: A Unified Perspective},
  author={Yongxin Zhu and Bocheng Li and Hang Zhang and Xin Li and Linli Xu and Lidong Bing},
  year={2024},
  eprint={2410.12490},
  archivePrefix={arXiv},
  primaryClass={cs.CV}
}
```