https://github.com/ashmod/learning-not-to-learn
PyTorch implementation of the paper "Learning Not to Learn," an adversarial approach to mitigate dataset bias in deep neural networks.
algorithms analysis data-science deep-learning machine-learning research
- Host: GitHub
- URL: https://github.com/ashmod/learning-not-to-learn
- Owner: ashmod
- License: mit
- Created: 2025-05-16T19:27:10.000Z (about 1 year ago)
- Default Branch: main
- Last Pushed: 2025-05-19T14:18:36.000Z (about 1 year ago)
- Last Synced: 2026-05-03T05:53:00.177Z (about 1 month ago)
- Topics: algorithms, analysis, data-science, deep-learning, machine-learning, research
- Language: Jupyter Notebook
- Homepage:
- Size: 950 KB
- Stars: 0
- Watchers: 0
- Forks: 0
- Open Issues: 0
Metadata Files:
- Readme: README.md
- License: LICENSE
# Learning Not To Learn: A User-Friendly Implementation
This repository provides a user-friendly PyTorch implementation of the paper "[Learning Not to Learn: Training Deep Neural Networks with Biased Data](https://arxiv.org/abs/1812.10352)" by Kim et al. (CVPR 2019). The goal of this project is to replicate the paper's adversarial approach to mitigating dataset bias in a way that is more accessible, more configurable, and easier to understand and extend.
The core idea is to train a feature extractor that "unlearns" bias information by making it difficult for a separate bias prediction network to identify the bias from the extracted features, while still allowing a main task classifier to perform well.
## Features
* **Modular Codebase:** Clear separation of concerns for data loading, model definitions, training logic, and utilities.
* **Configuration-Driven:** Experiments are managed via YAML configuration files, making it easy to define and track different setups.
* **Automated Workflow:** A `Makefile` is provided for common tasks like setup, training, evaluation, and cleaning.
* **User-Friendly Setup:** Includes scripts for dataset download and clear instructions.
* **Phased Training:** Supports optional pre-training phases for components.
* **Comprehensive Logging:** Console logging and TensorBoard integration for tracking metrics.
* **Robust Checkpointing:** Saves and loads model checkpoints, including optimizer states, for resuming training and evaluation. Includes "best model" saving.
* **Gradient Reversal Layer (GRL):** Implements GRL for the adversarial training component.
* **Jupyter Notebooks:** For data exploration and model sanity checks, plus a template for results analysis.
## Project Structure
```
learning-not-to-learn/
├── .gitignore # Git ignore file
├── Makefile # Automation for common tasks
├── README.md # This file
├── requirements.txt # Python package dependencies
├── environment.yml # Optional: For Conda environments
│
├── configs/ # Configuration files for experiments
│   └── colored_mnist_default.yaml # Example config
│
├── data/ # Default location for datasets
│   └── colored_mnist/ # Data for Colored MNIST (after download script)
│   └── README.md # Brief note about dataset storage
│
├── notebooks/ # Jupyter notebooks
│   ├── 01_data_exploration_and_visualization.ipynb
│   ├── 02_model_sanity_checks.ipynb
│   └── 03_results_analysis.ipynb (template)
│
├── results/ # Output directory (logs, checkpoints, plots) - in .gitignore
│
├── scripts/ # Utility scripts
│   └── download_dataset.sh # Script to download and extract Colored MNIST
│
└── src/ # Main source code
    ├── __init__.py
    ├── data_loader.py # Dataset and DataLoader classes
    ├── main.py # Main script for training and evaluation
    ├── models.py # Neural network architecture definitions (f, g, h, GRL)
    ├── trainer.py # Trainer class with training/evaluation logic
    └── utils.py # Utility functions (config loading, seeding, logging)
```
## Getting Started
### 1. Clone the Repository
```bash
git clone https://github.com/ashmod/learning-not-to-learn
cd learning-not-to-learn
```
### 2. Create a Virtual Environment (Recommended)
**Using `venv`:**
```bash
python -m venv venv
source venv/bin/activate # On Windows: venv\Scripts\activate
```
**Using Conda:**
```bash
conda create -n lntl_env python=3.9 # Or your preferred Python version
conda activate lntl_env
```
### 3. Set Up the Project Using the Makefile
The easiest way to install dependencies and download the dataset is by using the `Makefile`:
```bash
make setup
```
This command will:
1. Install Python dependencies from `requirements.txt` (including `gdown` for dataset download).
2. Download and extract the Colored MNIST dataset to `./data/colored_mnist/` using `scripts/download_dataset.sh`.
Alternatively, you can run these steps manually:
* Install dependencies: `make install` or `pip install -r requirements.txt && pip install gdown`
* Download data: `make download_data` or `bash scripts/download_dataset.sh`
## Configuration
Experiments are controlled by YAML configuration files located in the `configs/` directory. `configs/colored_mnist_default.yaml` is provided as a template. Modify this file or create new ones to define your experimental setup (dataset paths, model parameters, training hyperparameters, etc.).
You can list available configurations using:
```bash
make list_configs
```
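To make the dotted field names used in this README (for example `training.adversarial_lambda`) more concrete, here is a minimal sketch of how a nested configuration could be read and overridden once a YAML file has been parsed into a dictionary. The helper names `get_by_path` and `deep_update`, the `training.epochs` field, and every value shown are illustrative assumptions, not the project's actual schema or `utils.py` API.

```python
from copy import deepcopy

# Hypothetical config, shaped like a parsed YAML file. Only the key names
# project.experiment_name, logging.output_dir and training.adversarial_lambda
# appear in this README; training.epochs and all values are assumptions.
DEFAULT_CONFIG = {
    "project": {"experiment_name": "colored_mnist_baseline"},
    "logging": {"output_dir": "results"},
    "training": {"adversarial_lambda": 0.1, "epochs": 50},
}


def get_by_path(config, dotted_key):
    """Look up a nested value using a dotted key such as 'training.epochs'."""
    node = config
    for part in dotted_key.split("."):
        if not isinstance(node, dict) or part not in node:
            raise KeyError(f"Missing config field: {dotted_key}")
        node = node[part]
    return node


def deep_update(base, overrides):
    """Return a copy of `base` with `overrides` merged in recursively."""
    merged = deepcopy(base)
    for key, value in overrides.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_update(merged[key], value)
        else:
            merged[key] = value
    return merged


if __name__ == "__main__":
    # Override one field and leave the rest of the defaults untouched.
    custom = deep_update(DEFAULT_CONFIG, {"training": {"epochs": 100}})
    print(get_by_path(custom, "training.epochs"))
    print(get_by_path(custom, "training.adversarial_lambda"))
```

Merging overrides into a copy keeps the default configuration unchanged, which makes it easier to compare several experimental setups side by side.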
## Running Experiments with Makefile
The `Makefile` provides convenient shortcuts for common operations. Run `make help` to see all available targets.
### Training
* **Run a full training session** (uses `DEFAULT_TRAIN_EPOCHS` from Makefile, e.g., 50 epochs, and `configs/colored_mnist_default.yaml`):
```bash
make train
```
* **Run a quick 1-epoch test training:**
```bash
make quick_train
```
* **Customize training:**
```bash
make train EPOCHS=100 CONFIG=configs/your_custom_config.yaml DEVICE=cuda:0
```
Outputs (logs, TensorBoard files, checkpoints) are saved under the directory given by the `logging.output_dir` and `project.experiment_name` fields of your config (e.g., `results/colored_mnist_baseline/`).
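As a rough illustration of how the two config fields could combine into a run directory, the sketch below joins them with `pathlib`. The function name `experiment_paths` and the `tensorboard` sub-directory are assumptions made for this example; only the `checkpoints` sub-directory appears in the example paths in this README, and the real layout is decided by the project's own code.

```python
from pathlib import Path


def experiment_paths(output_dir, experiment_name):
    """Build the run directory and assumed sub-directories for one experiment."""
    run_dir = Path(output_dir) / experiment_name
    return {
        "run_dir": run_dir,
        # "checkpoints" matches the example checkpoint path in this README.
        "checkpoints": run_dir / "checkpoints",
        # "tensorboard" is an assumed name used only for illustration.
        "tensorboard": run_dir / "tensorboard",
    }


if __name__ == "__main__":
    paths = experiment_paths("results", "colored_mnist_baseline")
    for name, path in paths.items():
        print(f"{name}: {path.as_posix()}")
```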
### Resuming Training
To resume training from a checkpoint, run `src/main.py` directly and pass the checkpoint path:
```bash
python src/main.py --config configs/your_config.yaml --mode train --checkpoint_path path/to/your/checkpoint.pth
```
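For readers who want to see how the three flags in the command above might be parsed, here is a minimal `argparse` sketch. It is not the parser in `src/main.py`, which may define more options or different defaults; the `eval` mode name and the helper `is_resuming` are assumptions made for illustration.

```python
import argparse


def build_parser():
    """Sketch of a CLI accepting the flags used in the resume command."""
    parser = argparse.ArgumentParser(
        description="Sketch of a train/eval entry point (not the real src/main.py)."
    )
    parser.add_argument("--config", required=True, help="Path to a YAML config file.")
    # "train" is taken from the command above; "eval" is an assumed mode name.
    parser.add_argument("--mode", choices=["train", "eval"], default="train")
    parser.add_argument(
        "--checkpoint_path",
        default=None,
        help="Checkpoint to load; when omitted, training starts from scratch.",
    )
    return parser


def is_resuming(args):
    """Training resumes only when a checkpoint is supplied in train mode."""
    return args.mode == "train" and args.checkpoint_path is not None


if __name__ == "__main__":
    # Parse an explicit argument list so the sketch runs without real CLI input.
    args = build_parser().parse_args(
        ["--config", "configs/your_config.yaml", "--mode", "train",
         "--checkpoint_path", "path/to/your/checkpoint.pth"]
    )
    print("resuming" if is_resuming(args) else "not resuming")
```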
### Evaluation
* **Evaluate the best model** from the default experiment run:
```bash
make eval
```
* **Evaluate a specific checkpoint:**
```bash
make eval CHECKPOINT=./results/colored_mnist_baseline/checkpoints/checkpoint_epoch_XX.pth
```
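When several epoch checkpoints exist, it can be handy to pick the most recent one programmatically before passing it as `CHECKPOINT`. The sketch below assumes files follow the `checkpoint_epoch_<number>.pth` naming shown in the example path; the helper `latest_checkpoint` is hypothetical and not part of the repository.

```python
import re
import tempfile
from pathlib import Path

# Assumes the naming scheme from the example path: checkpoint_epoch_<number>.pth
CHECKPOINT_PATTERN = re.compile(r"^checkpoint_epoch_(\d+)\.pth$")


def latest_checkpoint(checkpoint_dir):
    """Return the checkpoint path with the highest epoch number, or None."""
    best_epoch, best_path = -1, None
    for path in Path(checkpoint_dir).glob("checkpoint_epoch_*.pth"):
        match = CHECKPOINT_PATTERN.match(path.name)
        # Compare epochs numerically so that epoch 10 ranks above epoch 2.
        if match and int(match.group(1)) > best_epoch:
            best_epoch, best_path = int(match.group(1)), path
    return best_path


if __name__ == "__main__":
    # Demonstrate on empty placeholder files in a temporary directory.
    with tempfile.TemporaryDirectory() as tmp:
        for epoch in (2, 10):
            (Path(tmp) / f"checkpoint_epoch_{epoch}.pth").touch()
        print(latest_checkpoint(tmp).name)
```

Parsing the epoch as an integer avoids the pitfall of plain string sorting, which would place `checkpoint_epoch_2.pth` after `checkpoint_epoch_10.pth`.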
### Monitoring with TensorBoard
* Launch TensorBoard for the default experiment:
```bash
make tensorboard
```
Then open `http://localhost:6006` in your browser.
## Jupyter Notebooks
The `notebooks/` directory contains:
* **`01_data_exploration_and_visualization.ipynb`**: Load and visualize the Colored MNIST dataset to understand its properties and biases.
* **`02_model_sanity_checks.ipynb`**: Instantiate and test individual model components ($f$, $g$, $h$) with dummy data to verify shapes and forward passes.
* **`03_results_analysis.ipynb`**: A template for loading metrics from TensorBoard, plotting learning curves, and performing other analyses on completed training runs.
Start Jupyter Lab/Notebook using:
```bash
make notebook
```
## Core Logic ("Learning Not To Learn")
1. **Feature Extractor ($f$)**: A network (e.g., CNN) that learns a representation of the input data.
2. **Task Classifier ($g$)**: A network (e.g., MLP) that takes features from $f$ and predicts the main task label.
3. **Bias Predictor ($h$)**: A network that takes features from $f$ and tries to predict a known bias in the data.
4. **Adversarial Training**:
* $f$ and $g$ are trained to minimize the main task classification loss.
* $h$ is trained to minimize its bias prediction loss.
* $f$ is *also* trained to *maximize* $h$'s bias prediction loss, typically via a Gradient Reversal Layer (GRL).
* The `training.adversarial_lambda` in the config controls the strength of this adversarial component.
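
Written as objectives, the training rules listed above amount to two coupled problems. Here $\mathcal{L}_c$ is the main task loss, $\mathcal{L}_b$ the bias prediction loss, $y$ the task label, $b$ the bias label, and $\lambda$ the value of `training.adversarial_lambda`. These symbols are notation introduced for this summary; it restates the description above and is not necessarily the exact loss used in the paper or in `trainer.py`.

$$
\begin{aligned}
\min_{f,\,g}\;& \mathcal{L}_c\big(g(f(x)),\,y\big) \;-\; \lambda\,\mathcal{L}_b\big(h(f(x)),\,b\big),\\
\min_{h}\;& \mathcal{L}_b\big(h(f(x)),\,b\big).
\end{aligned}
$$

The minus sign in the first line is what the GRL realizes in practice: $h$ descends on $\mathcal{L}_b$ while the gradient reaching $f$ is scaled by $-\lambda$.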
The goal is for $f$ to learn features that are discriminative for the main task but contain minimal information about the unwanted bias.
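The sketch below shows one common way to write a Gradient Reversal Layer in PyTorch and to use it in a single forward pass over $f$, $g$, and $h$. It is a simplified illustration under stated assumptions, not the code in `models.py` or `trainer.py`: the class name `GradientReversal`, the function `adversarial_losses`, the toy layer sizes, and the use of cross-entropy for both losses are all choices made for this example.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class GradientReversal(torch.autograd.Function):
    """Identity in the forward pass; scales gradients by -lambda in the backward pass."""

    @staticmethod
    def forward(ctx, x, lambda_):
        ctx.lambda_ = lambda_
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # lambda_ is a plain number and needs no gradient, hence the trailing None.
        return -ctx.lambda_ * grad_output, None


def adversarial_losses(f, g, h, x, y, b, adversarial_lambda):
    """Run one forward pass and return (task_loss, bias_loss)."""
    features = f(x)
    task_loss = F.cross_entropy(g(features), y)
    # h sees the same features, but gradients flowing back into f are reversed.
    reversed_features = GradientReversal.apply(features, adversarial_lambda)
    bias_loss = F.cross_entropy(h(reversed_features), b)
    return task_loss, bias_loss


if __name__ == "__main__":
    torch.manual_seed(0)
    # Toy stand-ins for f, g and h; all sizes are arbitrary assumptions.
    f = nn.Sequential(nn.Flatten(), nn.Linear(3 * 28 * 28, 32), nn.ReLU())
    g = nn.Linear(32, 10)
    h = nn.Linear(32, 8)
    x = torch.randn(4, 3, 28, 28)
    y = torch.randint(0, 10, (4,))
    b = torch.randint(0, 8, (4,))
    task_loss, bias_loss = adversarial_losses(f, g, h, x, y, b, adversarial_lambda=0.1)
    # One backward pass: h descends on bias_loss while f ascends on it.
    (task_loss + bias_loss).backward()
    print(task_loss.item(), bias_loss.item())
```

Because the reversal happens inside the backward pass, a single optimizer step on the summed loss updates $h$ to predict the bias better and $f$ to make that prediction harder, so no separate alternating update loop is needed in this sketch.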
## Citation
```bibtex
@inproceedings{kim2019learning,
title={Learning not to learn: Training deep neural networks with biased data},
author={Kim, Byungju and Kim, Hyunwoo and Kim, Kyungsu and Kim, Sungjin and Kim, Junmo},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={9012--9020},
year={2019}
}
```
## License
This project is licensed under the MIT License - see the [`LICENSE`](LICENSE) file for details.