{"id":15683702,"url":"https://github.com/rcmalli/lightning-maml","last_synced_at":"2025-05-07T14:11:38.290Z","repository":{"id":39638586,"uuid":"349251119","full_name":"rcmalli/lightning-maml","owner":"rcmalli","description":"MAML Implementation using Pytorch-lightning","archived":false,"fork":false,"pushed_at":"2022-05-30T07:01:13.000Z","size":49,"stargazers_count":22,"open_issues_count":4,"forks_count":5,"subscribers_count":1,"default_branch":"main","last_synced_at":"2025-05-07T14:11:30.577Z","etag":null,"topics":["higher","hydra","pytorch","pytorch-lightning","torchmeta"],"latest_commit_sha":null,"homepage":"","language":"Python","has_issues":true,"has_wiki":null,"has_pages":null,"mirror_url":null,"source_name":null,"license":"mit","status":null,"scm":"git","pull_requests_enabled":true,"icon_url":"https://github.com/rcmalli.png","metadata":{"files":{"readme":"README.md","changelog":null,"contributing":null,"funding":null,"license":"LICENSE","code_of_conduct":null,"threat_model":null,"audit":null,"citation":null,"codeowners":null,"security":null,"support":null}},"created_at":"2021-03-19T00:06:12.000Z","updated_at":"2025-01-13T14:43:36.000Z","dependencies_parsed_at":"2022-09-20T07:01:40.468Z","dependency_job_id":null,"html_url":"https://github.com/rcmalli/lightning-maml","commit_stats":null,"previous_names":[],"tags_count":0,"template":false,"template_full_name":"grok-ai/nn-template","repository_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/rcmalli%2Flightning-maml","tags_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/rcmalli%2Flightning-maml/tags","releases_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/rcmalli%2Flightning-maml/releases","manifests_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories/rcmalli%2Flightning-maml/manifests","owner_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners/rcmalli","download_url":"https://codeload.github.com/rcmalli/lightning-maml/tar.gz/refs/heads/main"
,"host":{"name":"GitHub","url":"https://github.com","kind":"github","repositories_count":252892504,"owners_count":21820648,"icon_url":"https://github.com/github.png","version":null,"created_at":"2022-05-30T11:31:42.601Z","updated_at":"2022-07-04T15:15:14.044Z","host_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub","repositories_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repositories","repository_names_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/repository_names","owners_url":"https://repos.ecosyste.ms/api/v1/hosts/GitHub/owners"}},"keywords":["higher","hydra","pytorch","pytorch-lightning","torchmeta"],"created_at":"2024-10-03T17:08:14.620Z","updated_at":"2025-05-07T14:11:38.266Z","avatar_url":"https://github.com/rcmalli.png","language":"Python","funding_links":[],"categories":[],"sub_categories":[],"readme":"# Pytorch Lightning MAML Implementation\n\n\u003cp align=\"center\"\u003e\n    \u003ca href=\"https://pytorch.org/get-started/locally/\"\u003e\u003cimg alt=\"PyTorch\" src=\"https://img.shields.io/badge/PyTorch-orange?logo=pytorch\"\u003e\u003c/a\u003e\n    \u003ca href=\"https://pytorchlightning.ai/\"\u003e\u003cimg alt=\"Lightning\" src=\"https://img.shields.io/badge/-Lightning-blueviolet\"\u003e\u003c/a\u003e\n    \u003ca href=\"https://hydra.cc/\"\u003e\u003cimg alt=\"Conf: hydra\" src=\"https://img.shields.io/badge/conf-hydra-blue\"\u003e\u003c/a\u003e\n    \u003ca href=\"https://wandb.ai/site\"\u003e\u003cimg alt=\"Logging: wandb\" src=\"https://img.shields.io/badge/logging-wandb-yellow\"\u003e\u003c/a\u003e\n    \u003ca href=\"https://black.readthedocs.io/en/stable/\"\u003e\u003cimg alt=\"Code style: black\" src=\"https://img.shields.io/badge/code%20style-black-000000.svg\"\u003e\u003c/a\u003e\n\u003c/p\u003e\n\nThis repository is a reimplementation\nof the [MAML](https://arxiv.org/abs/1703.03400) (Model-Agnostic Meta-Learning)\nalgorithm. 
Differentiable optimizers are handled\nby the [Higher](https://github.com/facebookresearch/higher) library,\nand [NN-template](https://github.com/lucmos/nn-template) is used to structure\nthe project. The default settings are configured for the Omniglot 5-way 5-shot\nproblem. The code can easily be extended to other few-shot datasets thanks to\nthe [Torchmeta](https://github.com/tristandeleu/pytorch-meta) library.\n\n## Quickstart\n\n**On Local Machine**\n\n1. Clone the repository and install the dependencies\n\n```bash\ngit clone https://github.com/rcmalli/lightning-maml.git\ncd ./lightning-maml/\npip install -r requirements.txt\n```\n\n2. Create a `.env` file containing the variables given below, using your\n   own [wandb.ai](https://wandb.ai)\n   account to track experiments. You can start from the `.env.template` file.\n\n```bash\nexport DATASET_PATH=\"/your/project/root/data/\"\nexport WANDB_ENTITY=\"USERNAME\"\nexport WANDB_API_KEY=\"KEY\"\n```\n\n3. Run the experiment\n\n```bash\npython3 src/run.py train.pl_trainer.gpus=1\n```\n\n**On Google Colab**\n\n[![Google Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/rcmalli/lightning-maml/blob/main/notebooks/lightning_maml_pub.ipynb)\n\n## Results\n\n### Omniglot (5-way 5-shot)\n\nFew-shot learning on this dataset is an easy task for the\nMAML algorithm, which can readily learn (or overfit) it.\n\n\u003ctable class=\"tg\"\u003e\n\u003cthead\u003e\n  \u003ctr\u003e\n    \u003cth colspan=\"3\"\u003e\u003c/th\u003e\n    \u003cth colspan=\"2\"\u003eMetatrain\u003c/th\u003e\n    \u003cth colspan=\"2\"\u003eMetavalidation\u003c/th\u003e\n  \u003c/tr\u003e\n\u003c/thead\u003e\n\u003ctbody\u003e\n  \u003ctr\u003e\n    \u003ctd \u003eAlgorithm\u003c/td\u003e\n    \u003ctd \u003eModel\u003c/td\u003e\n    \u003ctd \u003einner_steps\u003c/td\u003e\n    \u003ctd \u003einner accuracy\u003c/td\u003e\n    \u003ctd \u003eouter accuracy\u003c/td\u003e\n    
\u003ctd \u003einner accuracy\u003c/td\u003e\n    \u003ctd \u003eouter accuracy\u003c/td\u003e\n  \u003c/tr\u003e\n  \u003ctr\u003e\n    \u003ctd \u003eMAML\u003c/td\u003e\n    \u003ctd \u003eOmniConv\u003c/td\u003e\n    \u003ctd \u003e1\u003c/td\u003e\n    \u003ctd \u003e0.992\u003c/td\u003e\n    \u003ctd \u003e0.992\u003c/td\u003e\n    \u003ctd \u003e0.98\u003c/td\u003e\n    \u003ctd \u003e0.98\u003c/td\u003e\n  \u003c/tr\u003e\n  \u003ctr\u003e\n    \u003ctd \u003eMAML\u003c/td\u003e\n    \u003ctd \u003eOmniConv\u003c/td\u003e\n    \u003ctd \u003e5\u003c/td\u003e\n    \u003ctd \u003e1.0\u003c/td\u003e\n    \u003ctd \u003e1.0\u003c/td\u003e\n    \u003ctd \u003e1.0\u003c/td\u003e\n    \u003ctd \u003e1.0\u003c/td\u003e\n  \u003c/tr\u003e\n\u003c/tbody\u003e\n\u003c/table\u003e\n\n## Customization\n\nInside the `conf` folder, you can change all the settings to suit your problem\nor dataset. The default parameters are set for the Omniglot dataset. Here are some\nexamples of customization:\n\n### Debugging on a local machine without a GPU\n\nWith `fast_dev_run=true`, Lightning runs a single batch through the training,\nvalidation and test loops, which is enough to catch most bugs quickly:\n\n```bash\npython3 src/run.py train.pl_trainer.gpus=0 train.pl_trainer.fast_dev_run=true\n```\n\n### Running more inner steps and more epochs\n\n```bash\npython3 src/run.py train.pl_trainer.gpus=1 train.pl_trainer.max_epochs=1000 \\\ndata.datamodule.num_inner_steps=5\n```\n\n### Running a sweep of multiple runs\n\nWith Hydra's `-m` (`--multirun`) flag, one run is launched for each value in the\ncomma-separated list:\n\n```bash\npython3 src/run.py train.pl_trainer.gpus=1 data.datamodule.num_inner_steps=5,10,20 -m\n```\n\n### Using a different dataset from Torchmeta\n\nIf you want to try a different dataset (e.g. 
MiniImageNet), you can copy the\n`default.yaml` file inside `conf/data` to `miniimagenet.yaml` and edit these\nlines:\n\n```yaml\ndatamodule:\n  _target_: pl.datamodule.MetaDataModule\n\n  datasets:\n    train:\n      _target_: torchmeta.datasets.MiniImagenet\n      root: ${env:DATASET_PATH}\n      meta_train: True\n      download: True\n\n    val:\n      _target_: torchmeta.datasets.MiniImagenet\n      root: ${env:DATASET_PATH}\n      meta_val: True\n      download: True\n\n    test:\n      _target_: torchmeta.datasets.MiniImagenet\n      root: ${env:DATASET_PATH}\n      meta_test: True\n      download: True\n\n# You may also need to update the data augmentation and preprocessing steps!\n```\n\nRun the experiment as follows:\n\n```bash\npython3 src/run.py data=miniimagenet\n```\n\n## Implementing a different meta-learning algorithm\n\nIf you plan to implement a new variant of the MAML algorithm (for example\nMAML++), you can start by extending the [default Lightning module](https://github.com/rcmalli/lightning-maml/blob/44f271380bb6efc925a9070abe2ec4d0f7d88ef3/src/pl/model.py#L77) and its [step](https://github.com/rcmalli/lightning-maml/blob/44f271380bb6efc925a9070abe2ec4d0f7d88ef3/src/pl/model.py#L100-L150)\nfunction.\n\n## Notes\n\nA few modifications are required to run a meta-learning algorithm using\nPyTorch Lightning as a high-level library:\n\n1. In supervised learning, each epoch consists of `M` mini-batches. In the\n   meta-learning setting, however, a single meta-batch consists of `N` tasks. We have to\n   set the dataloader length to 1; otherwise, the dataloader will sample from the\n   dataset indefinitely.\n\n2. Unlike the test phase in traditional supervised learning, meta-learning also needs\n   gradient computation during validation and testing, because the model is first adapted\n   to the support set of each task. At the time of writing, PyTorch Lightning does not\n   provide a setting to enable gradient computation there, so you have to add a single line\n   at the beginning of your validation and test steps:\n   ```python\n   import torch\n\n   # First line of validation_step() and test_step()\n   torch.set_grad_enabled(True)\n   ```\n3. 
In the MAML algorithm, we have two different optimizers to train the model. The inner\n   optimizer must be differentiable, and the outer optimizer updates the model\n   using the query-set loss, which is computed with the weights adapted on the support\n   set during the inner loop. In PyTorch Lightning, optimizers are handled and weight\n   updates are done automatically. To disable this behaviour, we have to\n   set `automatic_optimization=False` and add the following lines to handle the backward\n   computation manually:\n   ```python\n   self.manual_backward(outer_loss, outer_optimizer)\n   outer_optimizer.step()\n   ```\n\n## References\n\n- [Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks](https://arxiv.org/abs/1703.03400)","project_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Frcmalli%2Flightning-maml","html_url":"https://awesome.ecosyste.ms/projects/github.com%2Frcmalli%2Flightning-maml","lists_url":"https://awesome.ecosyste.ms/api/v1/projects/github.com%2Frcmalli%2Flightning-maml/lists"}