https://github.com/jakesnell/prototypical-networks
Code for the NeurIPS 2017 Paper "Prototypical Networks for Few-shot Learning"
- Host: GitHub
- URL: https://github.com/jakesnell/prototypical-networks
- Owner: jakesnell
- License: mit
- Created: 2017-11-05T04:38:19.000Z (about 7 years ago)
- Default Branch: master
- Last Pushed: 2021-01-28T09:22:04.000Z (almost 4 years ago)
- Last Synced: 2024-10-15T10:05:46.780Z (25 days ago)
- Topics: deep-learning, few-shot, metric-learning, nips-2017, omniglot, pytorch
- Language: Python
- Size: 201 KB
- Stars: 1,110
- Watchers: 13
- Forks: 253
- Open Issues: 26
Metadata Files:
- Readme: README.md
- License: LICENSE
Awesome Lists containing this project
- awesome-few-shot-meta-learning - code - official (PyTorch)
README
# Prototypical Networks for Few-shot Learning
Code for the NIPS 2017 paper [Prototypical Networks for Few-shot Learning](http://papers.nips.cc/paper/6996-prototypical-networks-for-few-shot-learning.pdf).
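The paper's central idea is to summarize each class in an episode by a prototype, which is the mean of that class's embedded support examples. A query is then classified by a softmax over the negative squared Euclidean distances from its embedding to each prototype, and training minimizes the negative log-probability of the true class. The following plain-Python sketch illustrates that computation under stated assumptions: the embeddings are taken as already computed (the learned encoder is omitted), and the function names are hypothetical. They are not part of the `protonets` package, which implements the method in PyTorch.

```python
import math
from typing import Dict, List, Sequence, Tuple

Vector = List[float]


def compute_prototypes(support: Dict[str, List[Vector]]) -> Dict[str, Vector]:
    """Average each class's embedded support examples into a single prototype."""
    prototypes = {}
    for label, embeddings in support.items():
        dim = len(embeddings[0])
        prototypes[label] = [
            sum(e[d] for e in embeddings) / len(embeddings) for d in range(dim)
        ]
    return prototypes


def squared_euclidean(a: Sequence[float], b: Sequence[float]) -> float:
    """Squared Euclidean distance between two embeddings."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


def class_probabilities(query: Vector, prototypes: Dict[str, Vector]) -> Dict[str, float]:
    """Softmax over negative squared distances from the query to every prototype."""
    logits = {label: -squared_euclidean(query, p) for label, p in prototypes.items()}
    max_logit = max(logits.values())  # shift logits for numerical stability
    exps = {label: math.exp(v - max_logit) for label, v in logits.items()}
    total = sum(exps.values())
    return {label: v / total for label, v in exps.items()}


def episode_loss(queries: List[Tuple[Vector, str]], prototypes: Dict[str, Vector]) -> float:
    """Mean negative log-probability of each query's true class."""
    losses = [-math.log(class_probabilities(x, prototypes)[y]) for x, y in queries]
    return sum(losses) / len(losses)


# Toy 2-way, 2-shot episode; these 2-D vectors stand in for encoder outputs.
support = {"a": [[0.0, 0.0], [0.0, 2.0]], "b": [[4.0, 0.0], [4.0, 2.0]]}
protos = compute_prototypes(support)  # {"a": [0.0, 1.0], "b": [4.0, 1.0]}
probs = class_probabilities([1.0, 1.0], protos)  # query is much closer to "a"
loss = episode_loss([([1.0, 1.0], "a"), ([3.0, 1.0], "b")], protos)
```

Subtracting the largest logit before exponentiating does not change the softmax, but it keeps `math.exp` from overflowing when distances are large.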
If you use this code, please cite our paper:
```
@inproceedings{snell2017prototypical,
title={Prototypical Networks for Few-shot Learning},
author={Snell, Jake and Swersky, Kevin and Zemel, Richard},
booktitle={Advances in Neural Information Processing Systems},
year={2017}
}
```

## Training a prototypical network
### Install dependencies
* This code has been tested on Ubuntu 16.04 with Python 3.6 and PyTorch 0.4.
* Install [PyTorch and torchvision](http://pytorch.org/).
* Install [torchnet](https://github.com/pytorch/tnt) by running `pip install git+https://github.com/pytorch/tnt.git@master`.
* Install the protonets package by running `python setup.py install` or `python setup.py develop`.

### Set up the Omniglot dataset
* Run `sh download_omniglot.sh`.
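Few-shot training and evaluation are organized into episodes: N classes are drawn, and each contributes K labeled support examples plus a few query examples. A minimal sketch of drawing one such episode, assuming a hypothetical in-memory `dataset` dictionary that maps a class name to its examples (this is not the repository's own data loader, and the parameter names are illustrative):

```python
import random
from typing import Dict, List, Sequence, Tuple


def sample_episode(
    dataset: Dict[str, Sequence],
    n_way: int,
    n_shot: int,
    n_query: int,
    rng: random.Random,
) -> Tuple[Dict[str, list], List[Tuple[object, str]]]:
    """Draw one N-way, K-shot episode: a support list per class plus labeled queries."""
    # Sort the class names so the draw depends only on the seed, not dict order.
    classes = rng.sample(sorted(dataset), n_way)
    support: Dict[str, list] = {}
    queries: List[Tuple[object, str]] = []
    for label in classes:
        # Draw support and query examples together so they never overlap.
        examples = rng.sample(list(dataset[label]), n_shot + n_query)
        support[label] = examples[:n_shot]
        queries.extend((x, label) for x in examples[n_shot:])
    return support, queries


# Hypothetical toy data: 10 character classes with 20 image ids each.
dataset = {f"char{c}": [f"char{c}_img{i}" for i in range(20)] for c in range(10)}
support, queries = sample_episode(dataset, n_way=5, n_shot=1, n_query=5, rng=random.Random(0))
```

Passing a dedicated `random.Random` instance with a fixed seed keeps episodes reproducible without touching the global random state.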
### Train the model
* Run `python scripts/train/few_shot/run_train.py`. This will run training and place the results into `results`.
* You can specify a different output directory by passing in the option `--log.exp_dir EXP_DIR`, where `EXP_DIR` is your desired output directory.
* If you are running on a GPU, you can pass in the option `--data.cuda`.
* Re-run in trainval mode (training on the combined training and validation splits) with `python scripts/train/few_shot/run_trainval.py`. This will save your model into `results/trainval` by default.

### Evaluate
* Run evaluation as: `python scripts/predict/few_shot/run_eval.py --model.model_path results/trainval/best_model.pt`.
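Few-shot results are commonly reported as the mean accuracy over many randomly sampled test episodes, often together with a 95% confidence interval. The sketch below shows that summary step for a list of per-episode accuracies, assuming a normal approximation; the function name is hypothetical, and it is not claimed to reproduce the exact output of `run_eval.py`.

```python
import math
import statistics
from typing import Sequence, Tuple


def mean_and_ci95(episode_accuracies: Sequence[float]) -> Tuple[float, float]:
    """Return mean accuracy and the half-width of a normal-approximation 95% CI."""
    n = len(episode_accuracies)
    mean = statistics.mean(episode_accuracies)
    if n < 2:
        # The sample standard deviation is undefined for a single episode.
        return mean, 0.0
    std = statistics.stdev(episode_accuracies)  # sample standard deviation
    return mean, 1.96 * std / math.sqrt(n)


# Hypothetical accuracies from four evaluation episodes.
mean_acc, half_width = mean_and_ci95([0.9, 1.0, 0.95, 0.85])
```

With many evaluation episodes the interval narrows roughly as one over the square root of the episode count, which is why few-shot evaluations typically average over a large number of episodes.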