# AIM: Autoregressive Image Models

*Alaaeldin El-Nouby, Michal Klein, Shuangfei Zhai, Miguel Angel Bautista, Alexander Toshev, Vaishaal Shankar,
Joshua M Susskind, and Armand Joulin*

To appear at ICML 2024

[[`Paper`](https://arxiv.org/abs/2401.08541)] [[`BibTex`](#citation)]

This software project accompanies the research paper, [Scalable Pre-training of Large Autoregressive Image Models](https://arxiv.org/abs/2401.08541).

We introduce **AIM**, a collection of vision models pre-trained with an autoregressive generative objective.
We show that autoregressive pre-training of image features exhibits scaling properties similar to those of its
textual counterpart (i.e., Large Language Models). Specifically, we highlight two findings (a conceptual sketch of the objective follows the list below):
1. the model capacity can be trivially scaled to billions of parameters, and
2. AIM effectively leverages large collections of uncurated image data.
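
For intuition, the sketch below illustrates such a next-patch-prediction objective in PyTorch. It is a minimal conceptual sketch, not the training code released in this repository, and `causal_model` stands in for a hypothetical causally-masked patch predictor.
```python
import torch.nn.functional as F

# Conceptual sketch of an autoregressive patch-prediction objective (NOT the
# repository's training code): raster-order patches are regressed from all
# preceding patches, analogous to next-token prediction in language models.
def next_patch_loss(causal_model, images, patch_size=16):
    # images: (B, C, H, W) -> flattened patch sequence of shape (B, N, C*P*P)
    patches = F.unfold(images, kernel_size=patch_size, stride=patch_size).transpose(1, 2)
    pred = causal_model(patches[:, :-1])  # predict patch t+1 from patches <= t
    target = patches[:, 1:]
    return F.mse_loss(pred, target)
```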

## Installation
Please install PyTorch using the official [installation instructions](https://pytorch.org/get-started/locally/).
Afterward, install the package as:
```commandline
pip install git+https://git@github.com/apple/ml-aim.git
```
We also offer [MLX](https://github.com/ml-explore/mlx) backend support for research and experimentation on Apple silicon.
To enable MLX support, simply run:
```commandline
pip install mlx
```
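
To quickly check that the installation (and the optional MLX backend) succeeded, the imports used by the examples below can be exercised directly; this is only a sanity check, not part of the repository's API:
```python
# Sanity check: these imports should succeed after the installation steps above.
import aim.utils        # installed from the ml-aim repository
import mlx.core as mx   # only present if the optional MLX backend was installed

print(mx.array([1.0, 2.0, 3.0]).sum())
```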

## Usage
Below we provide an example of usage in [PyTorch](https://pytorch.org/):
```python
from PIL import Image

from aim.utils import load_pretrained
from aim.torch.data import val_transforms

img = Image.open(...)
model = load_pretrained("aim-600M-2B-imgs", backend="torch")
transform = val_transforms()

inp = transform(img).unsqueeze(0)
logits, features = model(inp)
```
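
As a follow-up to the snippet above, the outputs can be inspected with standard PyTorch operations. This assumes `logits` holds ImageNet-1k classification scores from the attention probe, as suggested by the evaluation tables below:
```python
# Continuing from the snippet above (generic PyTorch, not an AIM-specific API).
probs = logits.softmax(dim=-1)        # class probabilities from the probe logits
top1 = probs.argmax(dim=-1).item()    # index of the most likely class
print(logits.shape, top1)
```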

The same example in MLX:

```python
from PIL import Image
import mlx.core as mx

from aim.utils import load_pretrained
from aim.torch.data import val_transforms

img = Image.open(...)
model = load_pretrained("aim-600M-2B-imgs", backend="mlx")
transform = val_transforms()

inp = transform(img).unsqueeze(0)
inp = mx.array(inp.numpy())
logits, features = model(inp)
```
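
Note that MLX evaluates arrays lazily; forcing evaluation before inspecting the outputs is standard MLX usage, not something specific to this repository:
```python
# Continuing from the MLX snippet above: force lazy computations to run.
mx.eval(logits)
print(logits.shape)
```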

And in JAX:

```python
from PIL import Image
import jax.numpy as jnp

from aim.utils import load_pretrained
from aim.torch.data import val_transforms

img = Image.open(...)
model, params = load_pretrained("aim-600M-2B-imgs", backend="jax")
transform = val_transforms()

inp = transform(img).unsqueeze(0)
inp = jnp.array(inp)
(logits, features), _ = model.apply(params, inp, mutable=['batch_stats'])
```
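
For repeated inference, the `apply` call above can be wrapped in `jax.jit`. This is a generic JAX/Flax pattern and assumes `model` is a Flax module, as the example suggests:
```python
import jax

# Continuing from the JAX snippet above: JIT-compile the forward pass
# (generic JAX/Flax usage, not an AIM-specific API).
forward = jax.jit(lambda p, x: model.apply(p, x, mutable=['batch_stats']))
(logits, features), _ = forward(params, inp)
```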

## Pre-trained checkpoints

The pre-trained models can be accessed via [PyTorch Hub](https://pytorch.org/hub/) as:
```python
import torch

aim_600m = torch.hub.load("apple/ml-aim", "aim_600M")
aim_1b = torch.hub.load("apple/ml-aim", "aim_1B")
aim_3b = torch.hub.load("apple/ml-aim", "aim_3B")
aim_7b = torch.hub.load("apple/ml-aim", "aim_7B")
```
or via [HuggingFace Hub](https://huggingface.co/docs/hub/) as:
```python
from aim.torch.models import AIMForImageClassification

aim_600m = AIMForImageClassification.from_pretrained("apple/aim-600M")
aim_1b = AIMForImageClassification.from_pretrained("apple/aim-1B")
aim_3b = AIMForImageClassification.from_pretrained("apple/aim-3B")
aim_7b = AIMForImageClassification.from_pretrained("apple/aim-7B")
```
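
Both loading paths return standard PyTorch modules, so generic PyTorch steps apply afterwards, for example sanity-checking a checkpoint on a dummy batch and casting the larger models to bfloat16 to reduce memory. The 224x224 resolution below is an assumption based on the standard ImageNet-1k setting:
```python
import torch

# Generic PyTorch follow-up (not an AIM-specific API): sanity-check a loaded
# checkpoint and reduce the memory footprint of the larger models.
device = "cuda" if torch.cuda.is_available() else "cpu"
aim_600m = aim_600m.to(device).eval()
with torch.no_grad():
    out = aim_600m(torch.randn(1, 3, 224, 224, device=device))  # assumed input size
print(type(out))  # inspect the returned structure (e.g. logits)

# For the largest checkpoints, bfloat16 roughly halves the memory use.
aim_7b = aim_7b.to(device=device, dtype=torch.bfloat16).eval()
```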

### Pre-trained backbones

The following table contains pre-trained backbones used in our paper.

| model | #params | attn (best layer) | backbone, SHA256 |
| --- | --- | --- | --- |
| AIM-0.6B | 0.6B | 79.4% | link, 0d6f6b8f |
| AIM-1B | 1B | 82.3% | link, d254ecd3 |
| AIM-3B | 3B | 83.3% | link, 8475ce4e |
| AIM-7B | 7B | 84.0% | link, 184ed94c |


### Pre-trained attention heads

The table below contains the classification results on the ImageNet-1k validation set.

| model | top-1 IN-1k (last layer) | top-1 IN-1k (best layer) | attention head, SHA256 (last layer) | attention head, SHA256 (best layer) |
| --- | --- | --- | --- | --- |
| AIM-0.6B | 78.5% | 79.4% | link, 5ce5a341 | link, ebd45c05 |
| AIM-1B | 80.6% | 82.3% | link, db3be2ad | link, f1ed7852 |
| AIM-3B | 82.2% | 83.3% | link, 5c057b30 | link, ad380e16 |
| AIM-7B | 82.4% | 84.0% | link, 1e5c99ba | link, 73ecd732 |


## Reproducing the IN-1k classification results
The command below reproduces the [attention probe results](#pre-trained-attention-heads) on the ImageNet-1k
validation set. We run the evaluation using a single node with 8 GPUs:
```commandline
torchrun --standalone --nnodes=1 --nproc-per-node=8 main_attnprobe.py \
--model=aim-7B \
--batch-size=64 \
--data-path=/path/to/imagenet \
--probe-layers=best \
--backbone-ckpt-path=/path/to/backbone_ckpt.pth \
--head-ckpt-path=/path/to/head_ckpt.pth
```
By default, we probe features from the 6 intermediate layers that provide the best performance. To probe only the last layer instead, pass `--probe-layers=last`.

## Citation
If you find our work useful, please consider citing us as:
```bibtex
@article{el2024scalable,
title={Scalable Pre-training of Large Autoregressive Image Models},
author={El-Nouby, Alaaeldin and Klein, Michal and Zhai, Shuangfei and Bautista, Miguel Angel and Toshev, Alexander and Shankar, Vaishaal and Susskind, Joshua M and Joulin, Armand},
journal={International Conference on Machine Learning},
year={2024}
}
```