https://github.com/headless-start/peft-lora-vit
This repository contains LoRA fine-tuning of a Vision Transformer on Oxford-IIIT Pets and Flowers-102.
- Host: GitHub
- URL: https://github.com/headless-start/peft-lora-vit
- Owner: headless-start
- License: mit
- Created: 2026-05-30T14:27:01.000Z (18 days ago)
- Default Branch: main
- Last Pushed: 2026-06-10T11:41:42.000Z (7 days ago)
- Last Synced: 2026-06-10T13:18:23.167Z (7 days ago)
- Topics: computer-vision, deep-learning, fine-tuning, hydra, image-classification, lora, peft, python, pytorch, timm, transfer-learning, vision-transformer, vit, weights-and-biases
- Language: Python
- Size: 1.21 MB
- Stars: 1
- Watchers: 0
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
# Parameter-Efficient Fine-Tuning of a Vision Transformer (LoRA)
## Project Overview
This project demonstrates **parameter-efficient fine-tuning** of a **Vision Transformer (ViT-B/16)** for image classification using **LoRA**: strictly low-rank updates, no other PEFT method. An ImageNet-pretrained backbone is adapted to a new dataset by learning small low-rank deltas on the attention **query/value** projections, while the backbone itself stays frozen. This reaches near full fine-tuning accuracy while updating only a tiny fraction of the weights.
**Datasets**: Oxford-IIIT Pets (37 cat and dog breeds) and Oxford Flowers-102.
**Backbone**: `vit_base_patch16_224`, pretrained on ImageNet via `timm`.
**Goal**: Strong top-1 accuracy while training well under 5% of the model's parameters.
I built this as hands-on preparation for the PEFT/LoRA side of my thesis; everything here is a standalone prototype on public data and public weights.

---
## Key Features
1. **Hand-Written LoRA**:
- Low-rank matrices injected into the fused q/v attention projections (`B · A · x · α/r`, with `α = 2r` and `B` zero-initialised so training starts exactly from the pretrained model).
- Placement follows the original [LoRA paper (Hu et al., 2022)](https://arxiv.org/abs/2106.09685), whose placement study (§7.1) found adapting **q and v** the best use of a fixed parameter budget; k contributes least.
- Only the LoRA matrices and the classifier head are trainable; the backbone is fully frozen.
2. **Rank Ablation**:
- One command sweeps the LoRA rank over {4, 8, 16, 32} and plots accuracy and cost against rank.
3. **Tiny Checkpoints**:
- Only the LoRA weights and head are saved: a few MB instead of the full 327 MB checkpoint. Inference rebuilds the model from public pretrained weights and loads the LoRA weights on top.
4. **Solid Training Recipe**:
- AdamW with a 2-epoch linear warmup into cosine decay, drop-path 0.1, mixed precision.
5. **Configurable with Hydra**:
- Data, model, and training settings live in `configs/` and can be overridden straight from the command line.
6. **Experiment Tracking**:
- Metrics are logged to Weights & Biases in **offline** mode by default, so it runs without an account.
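
The injection described under Hand-Written LoRA can be sketched as a wrapper around a fused `qkv` linear layer, the layout `timm`'s ViT attention uses. This is a minimal sketch, not the repository's code: the class name `LoRAQKV`, the 0.01-scaled Gaussian init for `A`, and the toy dimensions are assumptions made for illustration. It shows the q and v slices receiving `B · A · x · α/r` with `α = 2r`, the k slice left untouched, and `B` starting at zero so the wrapped layer initially reproduces the frozen projection.

```python
import torch
import torch.nn as nn


class LoRAQKV(nn.Module):
    """Frozen fused qkv projection plus trainable low-rank deltas on q and v."""

    def __init__(self, qkv: nn.Linear, r: int = 8):
        super().__init__()
        dim = qkv.in_features
        self.qkv = qkv
        for p in self.qkv.parameters():
            p.requires_grad = False  # pretrained weights stay frozen
        alpha = 2 * r  # README: alpha = 2r
        self.scale = alpha / r
        # Assumption: small Gaussian init for A; B is zero so the delta starts at 0.
        self.A_q = nn.Parameter(torch.randn(r, dim) * 0.01)
        self.B_q = nn.Parameter(torch.zeros(dim, r))
        self.A_v = nn.Parameter(torch.randn(r, dim) * 0.01)
        self.B_v = nn.Parameter(torch.zeros(dim, r))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.qkv(x)  # (..., 3 * dim), laid out as [q | k | v]
        dq = (x @ self.A_q.T) @ self.B_q.T * self.scale
        dv = (x @ self.A_v.T) @ self.B_v.T * self.scale
        dk = torch.zeros_like(dq)  # k gets no update
        return out + torch.cat([dq, dk, dv], dim=-1)


# Toy check with hypothetical sizes (ViT-B/16 would use dim=768).
layer = LoRAQKV(nn.Linear(16, 48), r=4)
x = torch.randn(2, 5, 16)
same_at_init = torch.allclose(layer(x), layer.qkv(x))
n_trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
```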
---
## Findings
- **Top-1 Accuracy**: **94.9%** on the Pets validation set (weighted average recall, WAR), best run with rank 8 on q/v.
- **Trainable Parameters**: 323K out of 86.1M, just **0.38%** of the model.
- **Setup**: LoRA rank 8 on q/v, 25 epochs, AdamW with warmup + cosine decay, mixed precision.
- **Takeaway**: LoRA matches, and here slightly beats, full fine-tuning while training under half a percent of the weights.
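
The accuracy above is labelled weighted average recall (WAR). When each class's recall is weighted by its share of the samples, WAR reduces to plain top-1 accuracy, because the per-class supports cancel and what remains is the number of correct predictions over the number of samples. A small sketch with made-up labels (the function names and data are illustrative, not taken from the repository):

```python
from collections import Counter


def accuracy(y_true, y_pred):
    """Fraction of samples whose predicted label matches the true label."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def weighted_average_recall(y_true, y_pred):
    """Per-class recall, weighted by each class's share of the samples."""
    support = Counter(y_true)
    correct = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    n = len(y_true)
    return sum((support[c] / n) * (correct[c] / support[c]) for c in support)


# Made-up labels for three classes with unequal support.
y_true = [0, 0, 0, 1, 1, 2]
y_pred = [0, 0, 1, 1, 2, 2]
acc = accuracy(y_true, y_pred)
war = weighted_average_recall(y_true, y_pred)  # equals acc up to float rounding
```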

### Baselines: how much does LoRA actually buy?
The comparison that matters: LoRA against a frozen-backbone **linear probe** (lower bound) and **full fine-tuning** (upper bound), all under the same protocol:
| method | top-1 acc (WAR) | trainable params | checkpoint | s/epoch | peak VRAM |
|------------------|-----------------|------------------|------------|---------|-----------|
| linear probe | 93.5% | 28K (0.03%) | 0.1 MB | 20 | 0.7 GB |
| LoRA r=8 (ours) | **94.9%** | 323K (0.38%) | 1.2 MB | 31 | 3.7 GB |
| full fine-tuning | 93.9% | 85.8M (100%) | 327 MB | 41 | 2.5 GB* |

\* Full fine-tuning runs at batch 16 (the others at 64) to fit optimizer states for all 86M parameters into 8 GB; its per-sample memory is far higher.

LoRA beats the linear probe by **+1.4 points**, so the frozen features alone are not enough. It even edges out full fine-tuning by **+1.0** while training **265× fewer parameters** with a **270× smaller checkpoint**. On a 3.7K-image dataset, updating all 86M weights overfits where the low-rank update acts as a regulariser; this mirrors the LoRA paper, which reports LoRA matching or outperforming full fine-tuning on most benchmarks.
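
The two ratios quoted above can be checked against the table's rounded figures. The sketch below redoes that arithmetic; it assumes weights are stored as 32-bit floats and that the table's MB are binary megabytes (2^20 bytes), neither of which the README states, and it uses the rounded counts 323K and 85.8M.

```python
BYTES_PER_FP32 = 4  # assumption: checkpoints store 32-bit floats
MIB = 2 ** 20       # assumption: the table's "MB" are binary megabytes

lora_params = 323_000     # "323K" from the table (rounded)
full_params = 85_800_000  # "85.8M" from the table (rounded)

# Parameter ratio: about 265.6, quoted as 265x.
param_ratio = full_params / lora_params

# Checkpoint sizes under the assumptions above: about 1.23 and 327.3.
lora_ckpt_mib = lora_params * BYTES_PER_FP32 / MIB
full_ckpt_mib = full_params * BYTES_PER_FP32 / MIB

# The 270x checkpoint figure matches the ratio of the rounded sizes (about 272.5);
# without rounding, the size ratio is simply the parameter ratio again.
ckpt_ratio_rounded = 327 / 1.2
```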

### Placement Ablation
Which projections should carry the LoRA update? Sweeping every q/k/v subset at rank 8:
| placement | top-1 acc (WAR) | trainable params | % of total |
|-----------|-----------------|------------------|------------|
| q | 94.3% | 176K | 0.21% |
| k | 94.1% | 176K | 0.21% |
| v | 94.7% | 176K | 0.21% |
| q + k | 94.3% | 323K | 0.38% |
| q + v | **94.9%** | 323K | 0.38% |
| q + k + v | 94.7% | 471K | 0.55% |

**q + v wins.** k is the weakest single placement, and adding it to q+v does not help: q and k only shape the attention pattern through their inner product, so adapting q already covers it, while v changes the content being mixed and is complementary. This reproduces the placement study in [the LoRA paper](https://arxiv.org/abs/2106.09685) (§7.1, Table 5).
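
The trainable-parameter counts in the placement table (and in the rank table that follows) can be reproduced with a short calculation. The sketch assumes ViT-B/16 dimensions (12 blocks, width 768), one `A` of shape `(r, 768)` and one `B` of shape `(768, r)` per adapted projection per block, and a 37-class linear head with bias. These assumptions are not spelled out in the README, but they match its rounded figures.

```python
def trainable_params(r: int, n_proj: int, dim: int = 768,
                     depth: int = 12, n_classes: int = 37) -> int:
    """Trainable parameters for LoRA rank r on n_proj projections per block."""
    lora = depth * n_proj * 2 * r * dim  # A is (r, dim), B is (dim, r)
    head = dim * n_classes + n_classes   # classifier weight + bias
    return lora + head


# Placement sweep at rank 8: one, two, or three adapted projections.
placement_counts = {n: trainable_params(r=8, n_proj=n) for n in (1, 2, 3)}
# -> {1: 175909, 2: 323365, 3: 470821}, i.e. 176K, 323K, 471K

# Rank sweep on q + v (two projections).
rank_counts = {r: trainable_params(r=r, n_proj=2) for r in (4, 8, 16, 32)}
# -> {4: 175909, 8: 323365, 16: 618277, 32: 1208101}, i.e. 176K, 323K, 618K, 1.21M
```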

### Rank Ablation
With placement fixed at q+v, accuracy saturates almost immediately as the rank grows. Rank 4 is already within 0.1 points of the best, and rank 32 buys nothing for roughly 7× the parameters of rank 4:
| rank | top-1 acc (WAR) | trainable params | % of total |
|------|-----------------|------------------|------------|
| 4 | 94.8% | 176K | 0.21% |
| 8 | 94.9% | 323K | 0.38% |
| 16 | 94.6% | 618K | 0.72% |
| 32 | 94.9% | 1.21M | 1.39% |

Ablation numbers are single runs with the default recipe; reruns move individual cells by ±0.3 points. The repo default (rank 8 on q/v) is the configuration both sweeps select.

---
## How to Run
Works on Linux, macOS and Windows.
```bash
git clone https://github.com/headless-start/peft-lora-vit.git
cd peft-lora-vit
python -m venv .venv
source .venv/bin/activate # linux / macos
# .venv\Scripts\activate # windows
pip install -r requirements.txt
```
For GPU training install the CUDA build of PyTorch from [pytorch.org](https://pytorch.org/get-started/locally/) first; the plain `pip install` gives you a CPU build on some platforms.
```bash
# full run on Oxford-IIIT Pets (downloads on first use)
python train.py
# or train on Flowers-102 instead
python train.py data=flowers
# override anything from the command line
python train.py train.epochs=40 data.batch_size=32 model.lora.r=16
```
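Changing `train.epochs` presumably changes the length of the learning-rate schedule as well. As a rough sketch of the recipe named under Key Features (a 2-epoch linear warmup into cosine decay), the function below returns the learning rate at a fractional epoch. The base rate `1e-3`, the warmup starting from zero, the decay ending at zero, and the name `lr_at` are illustrative assumptions; the README does not give these details.

```python
import math


def lr_at(epoch: float, base_lr: float = 1e-3, warmup_epochs: float = 2.0,
          total_epochs: float = 25.0, min_lr: float = 0.0) -> float:
    """Learning rate after `epoch` epochs: linear warmup, then cosine decay."""
    if epoch < warmup_epochs:
        # Linear ramp from 0 up to base_lr over the warmup period.
        return base_lr * epoch / warmup_epochs
    # Fraction of the post-warmup period that has elapsed, capped at 1.
    progress = min((epoch - warmup_epochs) / (total_epochs - warmup_epochs), 1.0)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


# Sample the schedule once per epoch for the default 25-epoch run.
schedule = [lr_at(e) for e in range(26)]
```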
Sweep the LoRA rank (writes `results/ablation.json` and `results/ablation.png`):
```bash
python ablate.py # ranks 4, 8, 16, 32
python ablate.py --ranks 4,8 data=flowers
```
Compare LoRA against the baselines, a linear probe and full fine-tuning (writes `results/baselines.json` and `results/baselines.png`):
```bash
python baselines.py
```
Classify your own images with a trained checkpoint:
```bash
python predict.py path/to/cat.jpg path/to/dog.jpg
# path/to/cat.jpg: Abyssinian (100.0%), Russian Blue (0.0%), Shiba Inu (0.0%)
```
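The checkpoint used here holds only the LoRA matrices and the classifier head, as described under Tiny Checkpoints. One common way to produce and consume such a file, sketched below with a stand-in model, is to save only parameters with `requires_grad=True` and load them non-strictly on top of a freshly built backbone. The helper names and the `Tiny` module are hypothetical; the repository's actual save and load code may differ.

```python
import io

import torch
import torch.nn as nn


def trainable_state_dict(model: nn.Module) -> dict:
    """Collect only the parameters that receive gradients."""
    return {n: p.detach().cpu() for n, p in model.named_parameters() if p.requires_grad}


def load_trainable(model: nn.Module, state: dict) -> None:
    """Load a partial state dict; frozen backbone keys are expected to be missing."""
    result = model.load_state_dict(state, strict=False)
    if result.unexpected_keys:
        raise KeyError(f"unexpected keys in checkpoint: {result.unexpected_keys}")


class Tiny(nn.Module):
    """Stand-in for the real model: a frozen 'backbone' and a trainable 'head'."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(8, 8)
        self.head = nn.Linear(8, 3)
        for p in self.backbone.parameters():
            p.requires_grad = False

    def forward(self, x):
        return self.head(self.backbone(x))


trained = Tiny()
buffer = io.BytesIO()  # stands in for a checkpoint file on disk
torch.save(trainable_state_dict(trained), buffer)

# Rebuild: same "pretrained" backbone first, then the small checkpoint on top.
restored = Tiny()
restored.backbone.load_state_dict(trained.backbone.state_dict())
buffer.seek(0)
load_trainable(restored, torch.load(buffer, weights_only=True))
```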
Quick smoke test (CPU, small backbone, no downloads):
```bash
python train.py +experiment=smoke
```
Runs are logged to Weights & Biases offline by default; to sync to the cloud:
```bash
wandb login
python train.py wandb.mode=online
```
Training curves and `metrics.json` are written to `results/`; checkpoints go to `outputs/`.

---
## System Requirements
### Dependencies
- Python 3.10+
- Libraries: `torch`, `torchvision`, `timm`, `hydra-core`, `wandb`, `matplotlib`
- Hardware: CUDA GPU recommended (a CPU smoke run is supported)
### Reproducibility
- Runs on Linux, macOS and Windows; all paths and commands are OS-agnostic.
- Seeds are fixed (`seed: 42`). Reported numbers came from Python 3.13, `torch` 2.12, `torchvision` 0.27, `timm` 1.0.27 on a single RTX 4060; expect individual cells to move by ±0.3 points across reruns and library versions due to GPU non-determinism.
- On machines with little RAM, add `data.num_workers=0` to any command.
---
## License
This project is licensed under the MIT License. See the [LICENSE](LICENSE) file for details.