https://github.com/anto18671/fairface-diffusion
A diffusion model trained on FairFace with discrete conditioning on age, gender, and race. Generates high-quality face images using a custom UNet and DDPM scheduler.
- Host: GitHub
- URL: https://github.com/anto18671/fairface-diffusion
- Owner: anto18671
- License: mit
- Created: 2025-08-02T18:18:12.000Z (7 months ago)
- Default Branch: main
- Last Pushed: 2025-08-02T18:25:50.000Z (7 months ago)
- Last Synced: 2025-10-04T05:52:58.679Z (4 months ago)
- Topics: conditional-diffusion, diffusion-models, fairface-dataset
- Language: Python
- Homepage:
- Size: 894 KB
- Stars: 1
- Watchers: 0
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
README
# FairFace Conditioned Diffusion Model
This project trains a conditional diffusion model on the [FairFace](https://huggingface.co/datasets/HuggingFaceM4/FairFace) dataset using discrete attributes: **age**, **gender**, and **race**.
Each sample is conditioned on one of `9 × 2 × 7 = 126` unique combinations from the dataset's categorical metadata.
---
## Conditioning Setup
We compute a single integer condition ID from:
- **Age**: 9 classes
- **Gender**: 2 classes
- **Race**: 7 classes
The combined ID is:
```
condition_id = (age_id * 14) + (gender_id * 7) + race_id
```
Each condition is embedded with `nn.Embedding(126, 512)` and passed into a cross-attention UNet.
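This is a mixed-radix encoding and is fully invertible. A minimal sketch in plain Python (the decode helper is illustrative, not part of the repo):

```python
def encode_condition(age_id, gender_id, race_id):
    # Mixed-radix encoding: age (9 classes), gender (2), race (7)
    return age_id * 14 + gender_id * 7 + race_id

def decode_condition(condition_id):
    # Inverse mapping back to (age_id, gender_id, race_id)
    age_id, rest = divmod(condition_id, 14)
    gender_id, race_id = divmod(rest, 7)
    return age_id, gender_id, race_id

# Round-trip check over all 126 combinations
assert all(
    decode_condition(encode_condition(a, g, r)) == (a, g, r)
    for a in range(9) for g in range(2) for r in range(7)
)
```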
---
## Model Architecture
The model is a `UNet2DConditionModel` from Hugging Face `diffusers`:
- **Input**: 3-channel RGB (128×128)
- **Cross-attention**: 512-dim from conditional embeddings
- **Scheduler**: DDPM with 1000 timesteps
- **Mixed precision**: Enabled via `torch.amp`
---
## Generated Samples by Epoch
The samples below show **four different conditions** (0–3) generated from checkpoints trained for increasing numbers of epochs.
| Sample |
| ----------------------------------- |
|  |
---
## Dataset Preparation
```python
from datasets import load_dataset, concatenate_datasets

# Load both the 0.25 and 1.25 padding configurations
dataset_025_train = load_dataset("HuggingFaceM4/FairFace", "0.25", split="train")
dataset_025_val = load_dataset("HuggingFaceM4/FairFace", "0.25", split="validation")
dataset_125_train = load_dataset("HuggingFaceM4/FairFace", "1.25", split="train")
dataset_125_val = load_dataset("HuggingFaceM4/FairFace", "1.25", split="validation")

# Merge all four splits into one training set
merged_dataset = concatenate_datasets([
    dataset_025_train, dataset_025_val,
    dataset_125_train, dataset_125_val,
])
```
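Each merged example's labels can then be folded into the single condition ID (a sketch, assuming `age`, `gender`, and `race` are integer class labels as in the FairFace configs; `add_condition_id` is an illustrative helper, not the repo's code):

```python
def add_condition_id(example):
    # Fold the three categorical labels into one ID in [0, 126)
    example["condition_id"] = (
        example["age"] * 14 + example["gender"] * 7 + example["race"]
    )
    return example

# Would be applied with: merged_dataset = merged_dataset.map(add_condition_id)
example = {"age": 3, "gender": 1, "race": 5}
print(add_condition_id(example)["condition_id"])  # 3*14 + 1*7 + 5 = 54
```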
---
## Training Loop
```python
for images, cond_ids in dataloader:
    optimizer.zero_grad()

    # Encode condition: (B,) -> (B, 1, 512) for cross-attention
    emb = cond_emb(cond_ids).unsqueeze(1)

    # Noise injection
    t = torch.randint(0, 1000, (images.size(0),), device=device)
    noise = torch.randn_like(images)
    noisy = scheduler.add_noise(images, noise, t)

    # Denoise and compute loss under mixed precision
    with torch.autocast(device_type="cuda"):
        pred = unet(noisy, t, encoder_hidden_states=emb).sample
        loss = F.mse_loss(pred, noise)

    # Backprop with gradient scaling
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
```
Each epoch:
- Trains on ~200,000 images
- Saves a model checkpoint
- Generates and saves 3 sample images
---
## Checkpoint Structure
Trained models are saved as:
```
./models/unet_epoch_{EPOCH}.pt
```
Samples are saved as:
```
./models/sample_epoch_{EPOCH}_{i}.png
```
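Both naming schemes can be reproduced with plain f-strings (a sketch; `epoch` and `i` stand for the training-loop counters):

```python
epoch, i = 10, 0  # example values for one checkpoint and one sample
checkpoint_path = f"./models/unet_epoch_{epoch}.pt"
sample_path = f"./models/sample_epoch_{epoch}_{i}.png"
print(checkpoint_path)  # ./models/unet_epoch_10.pt
print(sample_path)      # ./models/sample_epoch_10_0.png
```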
---
## Sample Generation During Inference
```python
# Start from pure noise and denoise over all 1000 timesteps
x = torch.randn(1, 3, size, size, device=device)
with torch.no_grad():
    for t in reversed(range(1000)):
        t_tensor = torch.tensor([t], device=device)
        pred = unet(x, t_tensor, encoder_hidden_states=emb).sample
        x = scheduler.step(pred, t, x).prev_sample
```
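Before saving, the denoised tensor is typically mapped back to 8-bit pixels. A plain-Python sketch of that per-pixel mapping, assuming the images were normalized to [-1, 1] during training (the range and the `to_uint8` helper are assumptions, not repo code):

```python
def to_uint8(value):
    # Map a pixel from the assumed [-1, 1] range to an integer in [0, 255]
    scaled = (value + 1.0) / 2.0          # -> [0, 1]
    clamped = min(max(scaled, 0.0), 1.0)  # guard against sampler overshoot
    return round(clamped * 255)

print(to_uint8(-1.0), to_uint8(0.0), to_uint8(1.0))  # 0 128 255
```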
These are used to produce a visual progression of model capability over time.
---
## Results Summary
- Supports 126 discrete conditionings
- Mixed-precision training with `torch.amp`
- Checkpoints and visual samples saved every epoch
- Structured dataset merging and label encoding
---
## Outputs
```bash
./models/
├── unet_epoch_10.pt
├── unet_epoch_20.pt
├── ...
├── sample_epoch_10_0.png
├── sample_epoch_10_1.png
└── ...
./assets/
└── grid_epochs_10_90.png
```
---
## License
This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details.