# Prompt-based Continual Learning Official Jax Implementation
This codebase contains the implementation of two continual learning methods:
- **[Learning to Prompt for Continual Learning (L2P)](https://arxiv.org/pdf/2112.08654.pdf) (CVPR2022)** [[Google AI Blog]](https://ai.googleblog.com/2022/04/learning-to-prompt-for-continual.html)
- **[DualPrompt: Complementary Prompting for Rehearsal-free Continual Learning](https://arxiv.org/pdf/2204.04799.pdf) (ECCV2022)**

## Introduction
L2P is a novel continual learning technique which learns to dynamically prompt a pre-trained model to learn tasks sequentially under different task transitions. Different from mainstream rehearsal-based or architecture-based methods, L2P requires neither a rehearsal buffer nor test-time task identity. L2P can be generalized to various continual learning settings including the most challenging and realistic task-agnostic setting. L2P consistently outperforms prior state-of-the-art methods. Surprisingly, L2P achieves competitive results against rehearsal-based methods even without a rehearsal buffer.
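To make the mechanism concrete, here is a minimal JAX sketch of the query-key prompt lookup at the core of L2P; all names and shapes are illustrative and are not the API of this codebase:
```
# Minimal sketch of L2P's prompt-pool lookup (illustrative names only).
# A query feature from the frozen backbone selects the top-k closest prompts
# from a learnable pool, which are then prepended to the input tokens.
import jax.numpy as jnp

def select_prompts(query, prompt_keys, prompt_pool, top_k=5):
    # query: [D]; prompt_keys: [M, D] (learnable); prompt_pool: [M, L, D] (learnable)
    q = query / (jnp.linalg.norm(query) + 1e-8)
    k = prompt_keys / (jnp.linalg.norm(prompt_keys, axis=1, keepdims=True) + 1e-8)
    sim = k @ q                          # cosine similarity to each key, shape [M]
    idx = jnp.argsort(-sim)[:top_k]      # indices of the top-k closest keys
    return prompt_pool[idx].reshape(-1, prompt_pool.shape[-1])  # [top_k * L, D]

def prepend_prompts(tokens, prompts):
    # tokens: [N, D] embedded input; the backbone then runs on [top_k * L + N, D].
    return jnp.concatenate([prompts, tokens], axis=0)
```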
DualPrompt improves L2P by attaching complementary prompts to the pre-trained backbone, and then formulates the objective as learning task-invariant and task-specific "instructions". With extensive experimental validation, DualPrompt consistently sets state-of-the-art performance under the challenging class-incremental setting. In particular, DualPrompt outperforms recent advanced continual learning methods with relatively large buffer sizes. We also introduce a more challenging benchmark, Split ImageNet-R, to help generalize rehearsal-free continual learning research.
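A correspondingly hedged sketch of DualPrompt's idea: a shared general (task-invariant) prompt is attached at some transformer layers, and a selected expert (task-specific) prompt at others. This simplifies the paper's prefix-tuning attachment to plain prompt concatenation, and the layer indices and names are placeholders:
```
# Illustrative DualPrompt-style prompt attachment (simplified; the paper attaches
# prompts via prefix-tuning, and the layer choices below are placeholders).
import jax.numpy as jnp

def attach_prompts(tokens, layer_idx, g_prompt, e_prompts, task_id,
                   g_layers=(0, 1), e_layers=(2, 3, 4)):
    # tokens: [N, D] token embeddings entering transformer layer `layer_idx`
    if layer_idx in g_layers:   # task-invariant "instructions", shared by all tasks
        tokens = jnp.concatenate([g_prompt, tokens], axis=0)
    if layer_idx in e_layers:   # task-specific "instructions", picked by key matching
        tokens = jnp.concatenate([e_prompts[task_id], tokens], axis=0)
    return tokens
```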
Code is written by Zifeng Wang. Acknowledgement to https://github.com/google-research/nested-transformer.
This is not an officially supported Google product.
## Novel CL benchmark: Split ImageNet-R
The Split ImageNet-R benchmark is built upon [ImageNet-R](https://www.tensorflow.org/datasets/catalog/imagenet_r) by dividing the 200 classes into 10 tasks with 20 classes per task; see [libml/input_pipeline.py](libml/input_pipeline.py) for details, and the sketch after this list. We believe Split ImageNet-R is of great importance to the continual learning community, for the following reasons:
- Split ImageNet-R contains classes with different styles, which is closer to complicated real-world problems.
- The significant intra-class diversity poses a great challenge for rehearsal-based methods to work effectively with a small buffer size, thus encouraging the development of more practical, rehearsal-free methods.
- Pre-trained vision models are useful in practical continual learning. However, their training set usually includes ImageNet. Thus, Split ImageNet-R serves as a relatively fair and challenging benchmark, and an alternative to ImageNet-based benchmarks for continual learning that uses pre-trained models.
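To make the split concrete, here is a hypothetical sketch of a 10-task partition of the 200 classes (the authoritative split lives in [libml/input_pipeline.py](libml/input_pipeline.py)):
```
# Hypothetical sketch of a Split ImageNet-R style task partition:
# 200 classes divided into 10 disjoint tasks of 20 classes each.
import numpy as np

rng = np.random.default_rng(0)    # fixed seed so the split is reproducible
classes = rng.permutation(200)    # shuffled class ids 0..199
tasks = classes.reshape(10, 20)   # tasks[t] holds the 20 class ids of task t
```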
## PyTorch Reimplementation
The codebase has been reimplemented in PyTorch by Jaeho Lee in [l2p-pytorch](https://github.com/JH-LEE-KR/l2p-pytorch) and [dualprompt-pytorch](https://github.com/JH-LEE-KR/dualprompt-pytorch).

## Environment setup
```
pip install -r requirements.txt
```
After this, you may need to adjust your jax version according to your cuda driver version so that jax correctly identifies your GPUs (see [this issue](https://github.com/google/jax/issues/5231) for more details).

Note: The codebase has been thoroughly tested under the TPU environment using the newest JAX version. We are currently working on further verifying the GPU environment.
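Before launching training, a quick sanity check (a minimal sketch, not part of the repo's instructions) confirms that JAX actually sees your accelerators:
```
# Sanity check: this should list GPU/TPU devices, not only CpuDevice entries.
import jax
print(jax.devices())
```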
## Dataset preparation
Before running experiments for 5-datasets and CORe50, an additional dataset preparation step should be conducted as follows:
1. Download the CORe50 classification benchmark here: https://vlomonaco.github.io/core50/ and download not-mnist here: http://yaroslavvb.com/upload/notMNIST/
2. Transform them into TFDS compatible form following the tutorial in https://www.tensorflow.org/datasets/add_dataset
3. Replace the corresponding dataset paths `"PATH_TO_CORE50"` and `"PATH_TO_NOT_MNIST"` in [libml/input_pipeline.py](libml/input_pipeline.py) with the destination paths from step 2 (a quick load check is sketched below).
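To verify the conversion from step 2 worked, a hedged sketch; `core50` and the data directory are placeholders, use whatever builder name and path you registered:
```
# Hypothetical check that the TFDS conversion succeeded.
# 'core50' and data_dir are placeholders for your registered builder and path.
import tensorflow_datasets as tfds

ds = tfds.load('core50', split='train', data_dir='/path/to/tfds_data')
print(next(iter(tfds.as_numpy(ds))))  # inspect one example
```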
## Getting pretrained ViT model
The ViT-B/16 model used in this paper can be downloaded [here](https://storage.googleapis.com/vit_models/imagenet21k/ViT-B_16.npz).
Note: Our codebase actually supports various sizes of ViTs. If you would like to try variations of ViTs, feel free to change the `config.model_name` in the config files, following the valid options defined in [models/vit.py](models/vit.py).
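To confirm the downloaded checkpoint is intact (a minimal sketch; the filename comes from the download link above, and this is not the repo's loading code):
```
# .npz checkpoints load as a name -> array mapping; peek at a few entries.
import numpy as np

params = np.load('ViT-B_16.npz')
print(len(params.files), sorted(params.files)[:5])
```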
## Instructions on running L2P and DualPrompt
We provide the configuration files to train and evaluate L2P and DualPrompt on multiple benchmarks in [configs](configs/).

To run L2P on benchmark datasets:
```
python main.py --my_config configs/$L2P_CONFIG --workdir=./l2p --my_config.init_checkpoint=
```
where `$L2P_CONFIG` can be one of the following: `[cifar100_l2p.py, five_datasets_l2p.py, core50_l2p.py, cifar100_gaussian_l2p.py]`.

Note: we run our experiments using 8 V100 GPUs or 4 TPUs, and we specify a per-device batch size of 16 in the config files. This indicates that we use a total batch size of 128.
To run DualPrompt on benchmark datasets:
```
python main.py --my_config configs/$DUALPROMPT_CONFIG --workdir=./dualprompt --my_config.init_checkpoint=
```
where `$DUALPROMPT_CONFIG` can be one of the following: `[imr_dualprompt.py, cifar100_dualprompt.py]`.

## Visualize results
We use tensorboard to visualize the results. For example, if the working directory specified to run L2P is `workdir=./cifar100_l2p`, the command to check the results is as follows:
```
tensorboard --logdir ./cifar100_l2p
```
Here are the important metrics to keep track of, and their corresponding meanings:

| Metric | Description |
| ----------- | ----------- |
| accuracy_n | Accuracy of the n-th task |
| forgetting | Average forgetting up until the current task |
| avg_acc | Average evaluation accuracy up until the current task |
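For reference, these metrics typically follow the standard continual-learning definitions sketched below (a hedged sketch; the repo's exact implementation may differ in details such as indexing or averaging):
```
# acc[t, i] = evaluation accuracy on task i after training on task t (0-indexed).
import numpy as np

def avg_acc(acc, t):
    # Average accuracy over all tasks seen so far, measured after task t.
    return acc[t, :t + 1].mean()

def forgetting(acc, t):
    # For each earlier task, the drop from its best past accuracy to its
    # accuracy after task t, averaged over those tasks (defined for t >= 1).
    best_past = acc[:t, :t].max(axis=0)
    return (best_past - acc[t, :t]).mean()
```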
## Cite
```
@inproceedings{wang2022learning,
title={Learning to prompt for continual learning},
author={Wang, Zifeng and Zhang, Zizhao and Lee, Chen-Yu and Zhang, Han and Sun, Ruoxi and Ren, Xiaoqi and Su, Guolong and Perot, Vincent and Dy, Jennifer and Pfister, Tomas},
booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
pages={139--149},
year={2022}
}
```

```
@article{wang2022dualprompt,
title={DualPrompt: Complementary Prompting for Rehearsal-free Continual Learning},
author={Wang, Zifeng and Zhang, Zizhao and Ebrahimi, Sayna and Sun, Ruoxi and Zhang, Han and Lee, Chen-Yu and Ren, Xiaoqi and Su, Guolong and Perot, Vincent and Dy, Jennifer and others},
journal={European Conference on Computer Vision},
year={2022}
}
```