
https://github.com/headless-start/peft-lora-vit

This repository contains LoRA fine-tuning of a Vision Transformer on Oxford-IIIT Pets and Flowers-102.

computer-vision deep-learning fine-tuning hydra image-classification lora peft python pytorch timm transfer-learning vision-transformer vit weights-and-biases


# Parameter-Efficient Fine-Tuning of a Vision Transformer (LoRA)

## 📌 Project Overview
This project demonstrates **parameter-efficient fine-tuning** of a **Vision Transformer (ViT-B/16)** for image classification using **LoRA**: strictly low-rank updates, no other PEFT method. An ImageNet-pretrained backbone is adapted to a new dataset by learning small low-rank deltas on the attention **query/value** projections, while the backbone itself stays frozen. On Oxford-IIIT Pets this matches (and slightly exceeds) full fine-tuning accuracy while updating under 0.4% of the weights.

- **Datasets**: Oxford-IIIT Pets (37 cat and dog breeds) and Oxford Flowers-102.
- **Backbone**: `vit_base_patch16_224`, pretrained on ImageNet via `timm`.
- **Goal**: Strong top-1 accuracy while training well under 5% of the model's parameters.

I built this as hands-on preparation for the PEFT/LoRA side of my thesis; everything here is a standalone prototype on public data and public weights.

![Dataset Samples](results/pet_samples.png)

---

## 🚀 Key Features
1. **Hand-Written LoRA**:
   - Low-rank matrices are injected into the q and v parts of each block's fused qkv attention projection (`B · A · x · α/r`, with `α = 2r` and `B` zero-initialised so training starts exactly from the pretrained model).
   - Placement follows the original [LoRA paper (Hu et al., 2022)](https://arxiv.org/abs/2106.09685), whose placement study (§7.1) found adapting **q and v** the best use of a fixed parameter budget; k contributes least.
   - Only the LoRA matrices and the classifier head are trainable; the backbone is fully frozen.
2. **Rank Ablation**:
   - One command sweeps the LoRA rank over {4, 8, 16, 32} and plots accuracy and cost against rank.
3. **Tiny Checkpoints**:
   - Only the LoRA weights and head are saved: a few MB (1.2 MB for rank 8 on Pets) instead of the 327 MB full fine-tuned checkpoint. Inference rebuilds the model from public pretrained weights and loads the LoRA weights and head on top.
4. **Solid Training Recipe**:
   - AdamW with a 2-epoch linear warmup into cosine decay, drop-path 0.1, mixed precision.
5. **Configurable with Hydra**:
   - Data, model, and training settings live in `configs/` and can be overridden straight from the command line.
6. **Experiment Tracking**:
   - Metrics are logged to Weights & Biases in **offline** mode by default, so it runs without an account.
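
Item 1 can be sketched as a wrapper around the fused projection. This is a minimal sketch, not the repo's code: it assumes timm's `VisionTransformer` layout (each block exposes `blk.attn.qkv`, an `nn.Linear(dim, 3 * dim)` whose output is ordered q, k, v, and the classifier is `model.head`). The names `LoRAQKV` and `inject_lora`, and the 0.01 scale of the Gaussian init for `A`, are hypothetical choices.

```python
from typing import Optional

import torch
import torch.nn as nn


class LoRAQKV(nn.Module):
    """Fused qkv projection plus trainable low-rank deltas on the q and v slices only."""

    def __init__(self, qkv: nn.Linear, r: int = 8, alpha: Optional[float] = None):
        super().__init__()
        self.qkv = qkv                                          # frozen base layer, output [q | k | v]
        self.dim = qkv.in_features
        self.scale = (2.0 * r if alpha is None else alpha) / r  # alpha = 2r gives scale 2
        kw = dict(device=qkv.weight.device, dtype=qkv.weight.dtype)
        # A: small random Gaussian, B: zeros, so B @ A == 0 and the output starts unchanged
        self.A_q = nn.Parameter(torch.randn(r, self.dim, **kw) * 0.01)
        self.B_q = nn.Parameter(torch.zeros(self.dim, r, **kw))
        self.A_v = nn.Parameter(torch.randn(r, self.dim, **kw) * 0.01)
        self.B_v = nn.Parameter(torch.zeros(self.dim, r, **kw))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.qkv(x)                                   # (..., 3 * dim)
        dq = (x @ self.A_q.T @ self.B_q.T) * self.scale     # (..., dim), i.e. (alpha/r) * B A x
        dv = (x @ self.A_v.T @ self.B_v.T) * self.scale
        dk = torch.zeros_like(dq)                           # k gets no update
        return out + torch.cat([dq, dk, dv], dim=-1)


def inject_lora(model: nn.Module, r: int = 8) -> nn.Module:
    """Freeze the whole model, wrap every block's qkv, then unfreeze the classifier head."""
    for p in model.parameters():
        p.requires_grad = False
    for blk in model.blocks:
        blk.attn.qkv = LoRAQKV(blk.attn.qkv, r=r)  # freshly created A/B params are trainable
    for p in model.head.parameters():
        p.requires_grad = True
    return model
```

With timm, usage would look roughly like `model = timm.create_model("vit_base_patch16_224", pretrained=True, num_classes=37, drop_path_rate=0.1)` followed by `inject_lora(model, r=8)`. For ViT-B/16 that leaves 12 × 2 × (2 × 8 × 768) = 294,912 adapter weights plus the 28,453-parameter head trainable, which is the 323K reported under Findings.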

---

## 🔍 Findings
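- **Top-1 Accuracy**: **95.2%** on the Pets validation set (reported as weighted average recall, WAR, which with class-support weights equals plain top-1 accuracy), best run with rank 8 on q/v; the ablation tables below, which are single runs, list 94.9% for the same configuration.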
- **Trainable Parameters**: 323K out of 86.1M โ€” just **0.38%** of the model.
- **Setup**: LoRA rank 8 on q/v, 25 epochs, AdamW with warmup + cosine decay, mixed precision.
- **Takeaway**: LoRA matches, and here slightly beats, full fine-tuning while training under half a percent of the weights.
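
The Setup line (and item 4 under Key Features) describes a 2-epoch linear warmup into cosine decay over 25 epochs. A minimal per-epoch sketch of such a schedule, assuming the rate ramps linearly to its peak and then follows a half cosine that would reach zero one epoch past the end; `lr_multiplier` is a hypothetical name, and the repo may step per iteration or use a non-zero floor instead:

```python
import math


def lr_multiplier(epoch: int, total_epochs: int = 25, warmup_epochs: int = 2) -> float:
    """Factor applied to the base learning rate at the given (0-indexed) epoch."""
    if epoch < warmup_epochs:
        # linear warmup: 1/warmup, 2/warmup, ..., reaching the peak at the last warmup epoch
        return (epoch + 1) / warmup_epochs
    # cosine phase: progress goes from 0.0 at its first epoch towards 1.0 at total_epochs
    progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
    return 0.5 * (1.0 + math.cos(math.pi * progress))
```

In PyTorch the factor can drive `torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_multiplier)`, with `scheduler.step()` called once per epoch.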

![Training Curves](results/training_curve.png)

### Baselines: how much does LoRA actually buy?
The comparison that matters: LoRA against a frozen-backbone **linear probe** (lower bound) and **full fine-tuning** (upper bound), all under the same protocol:

| method | top-1 acc (WAR) | trainable params | checkpoint size | time/epoch (s) | peak VRAM |
|------------------|-----------------|------------------|-----------------|----------------|-----------|
| linear probe | 93.5% | 28K (0.03%) | 0.1 MB | 20 | 0.7 GB |
| LoRA r=8 (ours) | **94.9%** | 323K (0.38%) | 1.2 MB | 31 | 3.7 GB |
| full fine-tuning | 93.9% | 85.8M (100%) | 327 MB | 41 | 2.5 GB* |

\* Full fine-tuning runs at batch size 16 (the others at 64) so that the optimizer states for all 86M parameters fit into 8 GB; its peak VRAM is therefore not directly comparable, and its per-sample memory is far higher.

LoRA beats the linear probe by **+1.4 points**, so the frozen features alone are not enough, and it even edges out full fine-tuning by **+1.0** while training **265× fewer parameters** with a **270× smaller checkpoint**. On a 3.7K-image dataset, updating all 86M weights likely overfits, whereas the low-rank update acts as a regulariser; this mirrors the LoRA paper, which reports LoRA matching or outperforming full fine-tuning on most benchmarks.
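
The checkpoint sizes and the two ratios can be checked from parameter counts alone. A sketch, assuming fp32 weights (4 bytes each), sizes in MiB, no serialization overhead, and the standard ViT-B/16 backbone of 85,798,656 parameters without its ImageNet head:

```python
BYTES_PER_PARAM = 4          # fp32 weights
MIB = 2 ** 20                # assumption: the table's "MB" are MiB

HEAD = 768 * 37 + 37                 # 37-class linear head (weights + bias)
LORA_R8_QV = 12 * 2 * (2 * 8 * 768)  # 12 blocks x {q, v} x (A: 8x768 + B: 768x8)
BACKBONE = 85_798_656                # ViT-B/16 without its ImageNet head

params = {
    "linear probe": HEAD,
    "LoRA r=8": LORA_R8_QV + HEAD,
    "full fine-tuning": BACKBONE + HEAD,
}


def checkpoint_mib(n_params: int) -> float:
    """Size of the raw fp32 weights, ignoring any serialization overhead."""
    return n_params * BYTES_PER_PARAM / MIB


for name, n in params.items():
    print(f"{name:16s} {n:>11,d} params  {checkpoint_mib(n):7.2f} MiB")

print(f"{params['full fine-tuning'] / params['LoRA r=8']:.1f}x fewer trainable parameters")
```

With exact byte counts the size ratio equals the parameter ratio, about 265×; dividing the rounded table entries (327 / 1.2) gives the slightly larger 270×.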

![Baselines](results/baselines.png)

### Placement Ablation
Which projections should carry the LoRA update? Sweeping every q/k/v subset at rank 8:

| placement | top-1 acc (WAR) | trainable params | % of total |
|-----------|-----------------|------------------|------------|
| q | 94.3% | 176K | 0.21% |
| k | 94.1% | 176K | 0.21% |
| v | 94.7% | 176K | 0.21% |
| q + k | 94.3% | 323K | 0.38% |
| q + v | **94.9%** | 323K | 0.38% |
| q + k + v | 94.7% | 471K | 0.55% |

**q + v wins.** k is the weakest single placement, and adding it to q+v helps nothing. q and k only shape the attention pattern through their inner product, so adapting q already covers much of what adapting k would, while v changes the content being mixed and is complementary. This reproduces the placement study in [the LoRA paper](https://arxiv.org/abs/2106.09685) (§7.1, Table 5).
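
The inner-product argument can be written out for one attention head. In notation introduced here (not from the repo), let $x_i, x_j \in \mathbb{R}^d$ be two token embeddings, $W_q, W_k \in \mathbb{R}^{d \times d_h}$ that head's query and key projections, and $\Delta W_q, \Delta W_k$ the corresponding low-rank updates:

```math
\begin{aligned}
s_{ij} &= \tfrac{1}{\sqrt{d_h}}\, x_i^{\top} W_q W_k^{\top} x_j \\
\text{LoRA on } q: \quad s_{ij}' &= \tfrac{1}{\sqrt{d_h}}\, x_i^{\top} \bigl(W_q W_k^{\top} + \Delta W_q\, W_k^{\top}\bigr)\, x_j \\
\text{LoRA on } k: \quad s_{ij}' &= \tfrac{1}{\sqrt{d_h}}\, x_i^{\top} \bigl(W_q W_k^{\top} + W_q\, \Delta W_k^{\top}\bigr)\, x_j
\end{aligned}
```

Both updates add a correction to the same bilinear form $W_q W_k^{\top}$, so they largely overlap, although they are not interchangeable in general because each correction is shaped by the other, frozen projection. $\Delta W_v$ instead changes the values $v_j = (W_v + \Delta W_v)^{\top} x_j$ that the attention weights average over, which reweighting through the logits alone cannot reproduce.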

![Placement Ablation](results/placement.png)

### Rank Ablation
With placement fixed at q+v, the rank sweep shows accuracy saturating almost immediately: rank 4 is already within 0.1 points of the best, and rank 32 buys nothing for 7× the parameters of rank 4.

| rank | top-1 acc (WAR) | trainable params | % of total |
|------|-----------------|------------------|------------|
| 4 | 94.8% | 176K | 0.21% |
| 8 | 94.9% | 323K | 0.38% |
| 16 | 94.6% | 618K | 0.72% |
| 32 | 94.9% | 1.21M | 1.39% |

![Rank Ablation](results/ablation.png)

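Ablation numbers are single runs with the default recipe; reruns move individual cells by ±0.3 points, so the differences between ranks are within run-to-run noise. The repo default (rank 8 on q/v) is the configuration both sweeps select (rank 32 ties it on accuracy at 3.7× the parameters).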
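
The trainable-parameter columns of both ablation tables follow from one formula: each adapted projection adds $2rd$ weights per block ($A$ is $r \times d$, $B$ is $d \times r$), over 12 blocks, plus the new 37-class head. A quick check, assuming ViT-B/16's hidden size $d = 768$ (the helper name is hypothetical):

```python
D, BLOCKS = 768, 12           # ViT-B/16: hidden size, number of transformer blocks
HEAD = D * 37 + 37            # new 37-class head for Oxford-IIIT Pets (weights + bias)


def trainable_params(rank: int, n_proj: int) -> int:
    """LoRA A (rank x D) and B (D x rank) on n_proj of {q, k, v} per block, plus the head."""
    return BLOCKS * n_proj * (2 * rank * D) + HEAD


rows = {
    "q, r=8": trainable_params(8, 1),
    "q+v, r=8": trainable_params(8, 2),
    "q+k+v, r=8": trainable_params(8, 3),
    "q+v, r=4": trainable_params(4, 2),
    "q+v, r=16": trainable_params(16, 2),
    "q+v, r=32": trainable_params(32, 2),
}
for name, n in rows.items():
    print(f"{name:11s} {n:>9,d}")
```

Rank 4 on q+v and rank 8 on a single projection cost exactly the same, which is why both rows show 176K, and rank 32 costs 6.9× as much as rank 4.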

---

## ⚙️ How to Run
Works on Linux, macOS and Windows.

```bash
git clone https://github.com/headless-start/peft-lora-vit.git
cd peft-lora-vit

python -m venv .venv
source .venv/bin/activate # linux / macos
# .venv\Scripts\activate # windows

pip install -r requirements.txt
```

For GPU training, install the CUDA build of PyTorch from [pytorch.org](https://pytorch.org/get-started/locally/) first; a plain `pip install` gives you a CPU-only build on some platforms.

```bash
# full run on Oxford-IIIT Pets (downloads on first use)
python train.py

# or train on Flowers-102 instead
python train.py data=flowers

# override anything from the command line
python train.py train.epochs=40 data.batch_size=32 model.lora.r=16
```

Sweep the LoRA rank (writes `results/ablation.json` and `results/ablation.png`):

```bash
python ablate.py # ranks 4, 8, 16, 32
python ablate.py --ranks 4,8 data=flowers
```

Compare LoRA against the baselines, a linear probe and full fine-tuning (writes `results/baselines.json` and `results/baselines.png`):

```bash
python baselines.py
```

Classify your own images with a trained checkpoint:

```bash
python predict.py path/to/cat.jpg path/to/dog.jpg
# path/to/cat.jpg: Abyssinian (100.0%), Russian Blue (0.0%), Shiba Inu (0.0%)
```
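
Item 3 under Key Features says inference rebuilds the model from public pretrained weights and loads only the LoRA weights and head on top. A minimal sketch of that save/load pattern, assuming "trained" is identified by `requires_grad` at save time; `save_trainable` and `load_trainable` are hypothetical names, not necessarily what `predict.py` uses:

```python
import torch
import torch.nn as nn


def save_trainable(model: nn.Module, f) -> None:
    """Save only parameters that were trained (LoRA A/B and the head), not the frozen backbone."""
    trainable = {name for name, p in model.named_parameters() if p.requires_grad}
    state = {k: v.detach().cpu() for k, v in model.state_dict().items() if k in trainable}
    torch.save(state, f)


def load_trainable(model: nn.Module, f) -> list:
    """Load the small checkpoint on top of a model already built from pretrained weights."""
    state = torch.load(f, map_location="cpu")
    result = model.load_state_dict(state, strict=False)
    if result.unexpected_keys:
        raise KeyError(f"checkpoint keys not found in model: {result.unexpected_keys}")
    return result.missing_keys  # the frozen backbone keys, supplied by the pretrained weights
```

`strict=False` is what lets the partial state dict load; checking `unexpected_keys` still catches a checkpoint that does not match the model it is loaded into.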

Quick smoke test (CPU, small backbone, no downloads):

```bash
python train.py +experiment=smoke
```

Runs are logged to Weights & Biases offline by default; to sync to the cloud:

```bash
wandb login
python train.py wandb.mode=online
```

Training curves and `metrics.json` are written to `results/`; checkpoints go to `outputs/`.

---

## 🛠 System Requirements
### Dependencies
- Python 3.10+
- Libraries: `torch`, `torchvision`, `timm`, `hydra-core`, `wandb`, `matplotlib`
- Hardware: CUDA GPU recommended (a CPU smoke run is supported)

### Reproducibility
- Runs on Linux, macOS and Windows; all paths and commands are OS-agnostic.
- Seeds are fixed (`seed: 42`). Reported numbers came from Python 3.13, `torch` 2.12, `torchvision` 0.27, `timm` 1.0.27 on a single RTX 4060; expect individual cells to move by ±0.3 points across reruns (GPU non-determinism) and across library versions.
- On machines with little RAM, add `data.num_workers=0` to any command.
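
Fixing the seed usually means seeding every random generator a run touches. A minimal sketch, assuming Python's `random`, NumPy and PyTorch are the generators involved (`set_seed` is a hypothetical helper, not necessarily what `train.py` does); as noted above, this still leaves GPU non-determinism:

```python
import random


def set_seed(seed: int = 42) -> None:
    """Seed Python's RNG, plus NumPy and PyTorch when they are installed."""
    random.seed(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch
        torch.manual_seed(seed)  # seeds the CPU and all CUDA devices
    except ImportError:
        pass
```

Where exact repeatability matters more than speed, PyTorch also provides `torch.use_deterministic_algorithms(True)`, which raises an error for operations that have no deterministic implementation.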

---

## 📄 License
This project is licensed under the MIT License. See the [LICENSE](LICENSE) file for details.