https://github.com/Han-Jia/UNICORN-MAML
PyTorch implementation of "How to Train Your MAML to Excel in Few-Shot Classification"
# How to Train Your MAML to Excel in Few-Shot Classification
The code repository for "[How to Train Your MAML to Excel in Few-Shot Classification](https://arxiv.org/abs/2106.16245)" (Accepted by ICLR 2022) in PyTorch.
If you use any content of this repo for your work, please cite the following bib entry:
@inproceedings{ye2021UNICORN,
  author = {Han-Jia Ye and
            Wei-Lun Chao},
  title = {How to Train Your {MAML} to Excel in Few-Shot Classification},
  booktitle = {10th International Conference on Learning Representations ({ICLR})},
  year = {2022}
}
## Main idea of UNICORN-MAML
Model-agnostic meta-learning (MAML) is arguably the most popular meta-learning algorithm nowadays, given its flexibility to incorporate various model architectures and to be applied to different problems. Nevertheless, its performance on few-shot classification is far behind many recent algorithms dedicated to the problem. In this paper, we point out several key facets of how to train MAML to excel in few-shot classification. First, we find that a large number of gradient steps are needed for the inner loop update, which contradicts the common usage of MAML for few-shot classification. Second, we find that MAML is sensitive to the permutation of class assignments in meta-testing: for a few-shot task of N classes, there are exponentially many ways to assign the learned initialization of the N-way classifier to the N classes, leading to an unavoidably huge variance. Third, we investigate several ways for permutation invariance and find that learning a shared classifier initialization for all the classes performs the best. On benchmark datasets such as *Mini*ImageNet and *Tiered*ImageNet, our approach, which we name UNICORN-MAML, performs on a par with or even outperforms state-of-the-art algorithms, **while keeping the simplicity of MAML without adding any extra sub-networks**.
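To make the permutation point concrete, here is a small sketch that is not taken from the repository; the toy weight vectors are made up for illustration. It counts the N! ways to map the N output slots of a learned classifier initialization onto the N classes of a task, and shows that when every slot shares one initialization, all assignments collapse to the same head.

```python
from itertools import permutations
from math import factorial


def distinct_heads(head):
    """Count the distinct classifier heads reachable by reordering class slots."""
    return len({tuple(order) for order in permutations(head)})


n_way = 5

# Hypothetical per-slot weight vectors (a real head has one vector per class).
separate_head = [(0.1 * i, -0.2 * i) for i in range(1, n_way + 1)]
# A single shared vector copied to every slot, as in a shared initialization.
shared_head = [(0.3, -0.1)] * n_way

num_assignments = factorial(n_way)            # 5! ways to assign slots to classes
num_separate = distinct_heads(separate_head)  # each assignment yields a different head
num_shared = distinct_heads(shared_head)      # all assignments coincide

print(num_assignments, num_separate, num_shared)
```

For a 5-way task this prints `120 120 1`: with separate per-class initializations each of the 120 assignments starts adaptation from a different head, while a shared initialization removes that source of variance by construction.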
## Standard Few-shot Learning Results
Experimental results on few-shot learning datasets with a ResNet-12 backbone (the same as in [MetaOptNet](https://github.com/kjunelee/MetaOptNet)). We report average results over 10,000 randomly sampled few-shot learning episodes for stabilized evaluation.
**MiniImageNet Dataset**
| Setups | 1-Shot 5-Way | 5-Shot 5-Way |
|:--------:|:------------:|:------------:|
| ProtoMAML | 62.62 | 79.24 |
| MetaOptNet | 62.64 | 78.63 |
| DeepEMD | 65.91 | 82.41 |
| FEAT | **66.78** | 82.05 |
| MAML | 64.42 | 83.44 |
| UNICORN-MAML | [65.17](https://drive.google.com/file/d/15496NKRBNrOpyyx3tQ_wD9fB2tx4BeT1/view?usp=sharing) | **[84.30](https://drive.google.com/file/d/1gjjQYOAyzoePKL4tvoag-bHPkGi6CmEQ/view?usp=sharing)** |

**TieredImageNet Dataset**
| Setups | 1-Shot 5-Way | 5-Shot 5-Way |
|:--------:|:------------:|:------------:|
| ProtoMAML | 67.10 | 81.18 |
| MetaOptNet | 65.99 | 81.56 |
| DeepEMD | **71.52** | 86.03 |
| FEAT | 70.80 | 84.79 |
| MAML | 65.72 | 84.37 |
| UNICORN-MAML | 69.24 | **86.06** |

## Prerequisites
The following packages are required to run the scripts:
- [PyTorch-1.6 and torchvision](https://pytorch.org)
- Package [tensorboardX](https://github.com/lanpa/tensorboardX)
- Dataset: please download the dataset and put images into the folder data/[name of the dataset, miniimagenet or cub]/images
- Pre-trained weights: The pre-trained weights (used for initialization) can be downloaded [here](https://drive.google.com/drive/folders/1WiNF-qKm8yBH4KcC1cdW3gpEwrxTQ0qN?usp=sharing).
## Dataset
### MiniImageNet Dataset
The MiniImageNet dataset is a subset of ImageNet that includes 100 classes with 600 examples per class. We follow the [previous setup](https://github.com/twitter/meta-learning-lstm) and use 64 classes as *base* categories, and 16 and 20 classes as two sets of *novel* categories for model validation and evaluation, respectively.
### TieredImageNet Dataset
[TieredImageNet](https://github.com/renmengye/few-shot-ssl-public) is a large-scale dataset with more categories. It contains 351, 97, and 160 categories for model training, validation, and evaluation, respectively.
## Code Structures
To reproduce our experiments with UNICORN-MAML, please use **train_fsl.py**. There are four parts in the code.
- `model`: It contains the main files of the code, including the few-shot learning trainer, the dataloader, the network architectures, and baseline and comparison models.
- `data`: Images and splits for the data sets.
- `saves`: The pre-trained weights of different networks.
- `checkpoints`: To save the trained models.
## Model Training and Evaluation
Please use **train_fsl.py** and follow the instructions below. The file will automatically evaluate the model on the meta-test set with 10,000 tasks after the given number of epochs.
## Arguments
`train_fsl.py` takes the following command-line options (details are in `model/utils.py`):

**Task Related Arguments**
- `dataset`: Option for the dataset (`MiniImageNet`, `TieredImageNet`, or `CUB`), default to `MiniImageNet`
- `way`: The number of classes in a few-shot task during meta-training, default to `5`
- `eval_way`: The number of classes in a few-shot task during meta-test, default to `5`
- `shot`: Number of instances in each class in a few-shot task during meta-training, default to `1`
- `eval_shot`: Number of instances in each class in a few-shot task during meta-test, default to `1`
- `query`: Number of instances in each class to evaluate the performance during meta-training, default to `15`
- `eval_query`: Number of instances in each class to evaluate the performance during meta-test, default to `15`
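As a rough illustration of how `way`, `shot`, and `query` shape one few-shot task, the sketch below samples a single episode from a toy dataset. The function `sample_episode` and the placeholder class and image names are hypothetical and are not the repository's sampler.

```python
import random


def sample_episode(class_to_items, way=5, shot=1, query=15, rng=None):
    """Sample one N-way K-shot episode as (support, query) lists of (item, label)."""
    rng = rng or random.Random()
    classes = rng.sample(sorted(class_to_items), way)
    support, query_set = [], []
    for label, cls in enumerate(classes):
        # Draw disjoint support and query instances from the same class.
        items = rng.sample(class_to_items[cls], shot + query)
        support += [(item, label) for item in items[:shot]]
        query_set += [(item, label) for item in items[shot:]]
    return support, query_set


# Toy dataset: 20 hypothetical classes with 30 placeholder image names each.
toy_data = {f"class_{c}": [f"img_{c}_{i}" for i in range(30)] for c in range(20)}
support, query_set = sample_episode(
    toy_data, way=5, shot=1, query=15, rng=random.Random(0)
)
print(len(support), len(query_set))
```

With the default values, an episode holds `way * shot = 5` support instances and `way * query = 75` query instances, and labels are re-indexed from 0 to `way - 1` within each task.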
**Optimization Related Arguments**
- `max_epoch`: The maximum number of training epochs, default to `200`
- `episodes_per_epoch`: The number of tasks sampled in each epoch, default to `100`
- `num_eval_episodes`: The number of tasks sampled from the meta-val set to evaluate the performance of the model (note that we fix sampling 10,000 tasks from the meta-test set during final evaluation), default to `200`
- `lr`: Learning rate for the model, default to `0.001` with pre-trained weights
- `lr_mul`: This is specially designed for set-to-set functions like FEAT. The learning rate for the top layer will be multiplied by this value (usually giving it a faster learning rate). Default to `10`
- `lr_scheduler`: The scheduler to set the learning rate (`step`, `multistep`, or `cosine`), default to `step`
- `step_size`: The step at which the scheduler decreases the learning rate. Set it to a single value when choosing the `step` scheduler, and provide multiple values when choosing the `multistep` scheduler. Default to `20`
- `gamma`: Learning rate ratio for `step` or `multistep` scheduler, default to `0.1`
- `fix_BN`: Set the encoder to the evaluation mode during the meta-training. This parameter is useful when meta-learning with the WRN. Default to `False`
- `mom`: The momentum value for the SGD optimizer, default to `0.9`
- `weight_decay`: The weight_decay value for SGD optimizer, default to `0.0005`
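The interaction of `lr`, `step_size`, and `gamma` can be sketched in a few lines. The helpers below follow the decay rule used by PyTorch's `StepLR` and `MultiStepLR` schedulers; that the repository builds its schedulers from those classes is an assumption, and the function names are made up for illustration.

```python
from bisect import bisect_right


def step_lr(base_lr, epoch, step_size=20, gamma=0.1):
    """Step schedule: multiply the rate by `gamma` every `step_size` epochs."""
    return base_lr * gamma ** (epoch // step_size)


def multistep_lr(base_lr, epoch, milestones, gamma=0.1):
    """Multistep schedule: multiply the rate by `gamma` at each milestone epoch."""
    return base_lr * gamma ** bisect_right(sorted(milestones), epoch)


for epoch in (0, 19, 20, 45):
    print(epoch, step_lr(0.001, epoch), multistep_lr(0.001, epoch, milestones=[20, 40]))
```

With the defaults (`lr` of `0.001`, `step_size` of `20`, `gamma` of `0.1`), the rate stays at `0.001` for epochs 0 to 19, drops by a factor of ten at epoch 20, and again at epoch 40.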
**Model Related Arguments**
- `model_class`: The model to use during meta-learning. We provide implementations for `MAML` and our `MAMLUnicorn`. Default to `MAML`
- `backbone_class`: Type of the encoder, e.g., ResNet-12 (`Res12`), default to `ConvNet`
- `temperature`: Temperature over the logits; we *divide* the logits by this value. It is useful when meta-learning with pre-trained weights. Default to `0.5`
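The effect of dividing logits by `temperature` can be seen in a minimal, framework-free sketch. The function name and the example logits are illustrative only and do not come from the repository.

```python
import math


def softmax_with_temperature(logits, temperature=0.5):
    """Divide logits by `temperature`, then apply a numerically stable softmax."""
    scaled = [z / temperature for z in logits]
    peak = max(scaled)
    exps = [math.exp(z - peak) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.0]
plain = softmax_with_temperature(logits, temperature=1.0)
sharp = softmax_with_temperature(logits, temperature=0.5)
print(plain, sharp)
```

A temperature below 1 enlarges the gaps between logits, so the resulting distribution is more peaked and the cross-entropy gradients are larger; a temperature above 1 has the opposite effect.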
**Other Arguments**
- `gpu`: The index of the GPU to use. Please provide multiple indexes if you choose `multi_gpu`. Default to `0`
- `log_interval`: How often to log the meta-training information, default to every `50` tasks
- `eval_interval`: How often to validate the model over the meta-val set, default to every `1` epoch
- `save_dir`: The path to save the learned models, default to `./checkpoints`
Running the command without arguments will train the models with the default hyper-parameter values. Loss changes will be recorded as a tensorboard file.
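The training commands in this README also pass `--gd_lr` and `--inner_iters`, which the lists above do not describe. Assuming they set the inner-loop step size and the number of inner-loop gradient steps, the sketch below shows what such an inner loop does for a linear classifier head on fixed toy features. It is a simplified illustration, not the repository's implementation: it adapts only the head, skips the outer-loop meta-update, and all function names and data are hypothetical.

```python
import math


def softmax(logits):
    peak = max(logits)
    exps = [math.exp(z - peak) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def class_logits(weights, x):
    """Logits of a bias-free linear head: one dot product per class."""
    return [sum(w_d * x_d for w_d, x_d in zip(w, x)) for w in weights]


def support_loss(weights, support):
    """Mean cross-entropy over (feature, label) pairs of the support set."""
    losses = [-math.log(softmax(class_logits(weights, x))[y]) for x, y in support]
    return sum(losses) / len(losses)


def inner_loop(init_weights, support, gd_lr, inner_iters):
    """Run `inner_iters` plain gradient steps of size `gd_lr` on the support set."""
    weights = [list(w) for w in init_weights]  # copy so the initialization is kept
    for _ in range(inner_iters):
        grads = [[0.0] * len(w) for w in weights]
        for x, y in support:
            probs = softmax(class_logits(weights, x))
            for c, p in enumerate(probs):
                # Cross-entropy gradient w.r.t. logit c is (p_c - 1[c == y]).
                coeff = (p - (1.0 if c == y else 0.0)) / len(support)
                for d, x_d in enumerate(x):
                    grads[c][d] += coeff * x_d
        for c, grad in enumerate(grads):
            for d, g in enumerate(grad):
                weights[c][d] -= gd_lr * g
    return weights


# Toy 2-way 1-shot task with made-up 2-D features.
support = [((1.0, 0.0), 0), ((0.0, 1.0), 1)]
shared_init = [[0.0, 0.0], [0.0, 0.0]]  # the same initial vector for every class

loss_before = support_loss(shared_init, support)
loss_few = support_loss(inner_loop(shared_init, support, gd_lr=0.1, inner_iters=1), support)
loss_many = support_loss(inner_loop(shared_init, support, gd_lr=0.1, inner_iters=15), support)
print(loss_before, loss_few, loss_many)
```

On this toy task the support loss keeps decreasing as the number of steps grows, which is the kind of behavior that motivates running more than one or two inner-loop steps.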
## Training scripts for UNICORN-MAML
For example, to train the 1-shot/5-shot 5-way MAML/UNICORN-MAML model with ResNet-12 backbone on MiniImageNet:
$ python train_fsl.py --max_epoch 100 --way 5 --eval_way 5 --lr_scheduler step --model_class MAML --lr_mul 10 --backbone_class Res12 --dataset MiniImageNet --gpu 0 --query 15 --step_size 20 --gamma 0.1 --para_init './saves/initialization/miniimagenet/Res12-pre.pth' --lr 0.001 --shot 1 --eval_shot 1 --temperature 0.5 --gd_lr 0.05 --inner_iters 15
$ python train_fsl.py --max_epoch 100 --way 5 --eval_way 5 --lr_scheduler step --model_class MAML --lr_mul 10 --backbone_class Res12 --dataset MiniImageNet --gpu 0 --query 15 --step_size 20 --gamma 0.1 --para_init './saves/initialization/miniimagenet/Res12-pre.pth' --lr 0.001 --shot 5 --eval_shot 5 --temperature 0.5 --gd_lr 0.1 --inner_iters 20
$ python train_fsl.py --max_epoch 100 --way 5 --eval_way 5 --lr_scheduler step --model_class MAMLUnicorn --lr_mul 10 --backbone_class Res12 --dataset MiniImageNet --gpu 0 --query 15 --step_size 20 --gamma 0.1 --para_init './saves/initialization/miniimagenet/Res12-pre.pth' --lr 0.001 --shot 1 --eval_shot 1 --temperature 0.5 --gd_lr 0.1 --inner_iters 5
$ python train_fsl.py --max_epoch 100 --way 5 --eval_way 5 --lr_scheduler step --model_class MAMLUnicorn --lr_mul 10 --backbone_class Res12 --dataset MiniImageNet --gpu 0 --query 15 --step_size 20 --gamma 0.1 --para_init './saves/initialization/miniimagenet/Res12-pre.pth' --lr 0.001 --shot 5 --eval_shot 5 --temperature 0.5 --gd_lr 0.1 --inner_iters 20
To train the 1-shot/5-shot 5-way MAML/UNICORN-MAML model with ResNet-12 backbone on TieredImageNet:
$ python train_fsl.py --max_epoch 100 --way 5 --eval_way 5 --lr_scheduler step --model_class MAML --lr_mul 10 --backbone_class Res12 --dataset TieredImageNet --gpu 0 --query 15 --step_size 20 --gamma 0.1 --para_init './saves/initialization/tieredimagenet/Res12-pre.pth' --lr 0.001 --shot 1 --eval_shot 1 --temperature 0.5 --gd_lr 0.01 --inner_iters 20
$ python train_fsl.py --max_epoch 100 --way 5 --eval_way 5 --lr_scheduler step --model_class MAML --lr_mul 10 --backbone_class Res12 --dataset TieredImageNet --gpu 0 --query 15 --step_size 20 --gamma 0.1 --para_init './saves/initialization/tieredimagenet/Res12-pre.pth' --lr 0.001 --shot 1 --eval_shot 5 --temperature 0.5 --gd_lr 0.05 --inner_iters 15
$ python train_fsl.py --max_epoch 100 --way 5 --eval_way 5 --lr_scheduler step --model_class MAMLUnicorn --lr_mul 10 --backbone_class Res12 --dataset TieredImageNet --gpu 0 --query 15 --step_size 20 --gamma 0.1 --para_init './saves/initialization/tieredimagenet/Res12-pre.pth' --lr 0.001 --shot 5 --eval_shot 1 --temperature 0.5 --gd_lr 0.02 --inner_iters 10
$ python train_fsl.py --max_epoch 100 --way 5 --eval_way 5 --lr_scheduler step --model_class MAMLUnicorn --lr_mul 10 --backbone_class Res12 --dataset TieredImageNet --gpu 0 --query 15 --step_size 20 --gamma 0.1 --para_init './saves/initialization/tieredimagenet/Res12-pre.pth' --lr 0.001 --shot 1 --eval_shot 5 --temperature 0.5 --gd_lr 0.05 --inner_iters 20
## Verifying the permutation variance of a learned MAML model
We can evaluate a learned MAML model and check whether permuting the class assignments introduces a large variance. For example, for the 1-shot/5-shot 5-way models with ResNet-12 backbone on MiniImageNet:
$ python eval_maml_permutation.py --shot_list 1 --model_path './MAML-1-shot.pth' --gpu 0 --gd_lr 0.05 --inner_iters 15 --model_class MAML --dataset MiniImageNet
$ python eval_maml_permutation.py --shot_list 5 --model_path './MAML-5-shot.pth' --gpu 0 --gd_lr 0.1 --inner_iters 20 --model_class MAML --dataset MiniImageNet
## Acknowledgment
We thank the following repos for providing helpful components/functions used in our work.
- [FEAT](https://github.com/Sha-Lab/FEAT)
- [AVIATOR](https://github.com/Han-Jia/AVIATOR)