Ecosyste.ms: Awesome

An open API service indexing awesome lists of open source software.

https://github.com/jakesnell/prototypical-networks

Code for the NeurIPS 2017 Paper "Prototypical Networks for Few-shot Learning"
https://github.com/jakesnell/prototypical-networks

deep-learning few-shot metric-learning nips-2017 omniglot pytorch

Last synced: 11 days ago
JSON representation

Code for the NeurIPS 2017 Paper "Prototypical Networks for Few-shot Learning"

Lists

README

        

# Prototypical Networks for Few-shot Learning

Code for the NIPS 2017 paper [Prototypical Networks for Few-shot Learning](http://papers.nips.cc/paper/6996-prototypical-networks-for-few-shot-learning.pdf).

If you use this code, please cite our paper:

```
@inproceedings{snell2017prototypical,
title={Prototypical Networks for Few-shot Learning},
author={Snell, Jake and Swersky, Kevin and Zemel, Richard},
booktitle={Advances in Neural Information Processing Systems},
year={2017}
}
```

## Training a prototypical network

### Install dependencies

* This code has been tested on Ubuntu 16.04 with Python 3.6 and PyTorch 0.4.
* Install [PyTorch and torchvision](http://pytorch.org/).
* Install [torchnet](https://github.com/pytorch/tnt) by running `pip install git+https://github.com/pytorch/tnt.git@master`.
* Install the protonets package by running `python setup.py install` or `python setup.py develop`.

### Set up the Omniglot dataset

* Run `sh download_omniglot.sh`.

### Train the model

* Run `python scripts/train/few_shot/run_train.py`. This will run training and place the results into `results`.
* You can specify a different output directory by passing in the option `--log.exp_dir EXP_DIR`, where `EXP_DIR` is your desired output directory.
* If you are running on a GPU you can pass in the option `--data.cuda`.
* Re-run in trainval mode `python scripts/train/few_shot/run_trainval.py`. This will save your model into `results/trainval` by default.

### Evaluate

* Run evaluation as: `python scripts/predict/few_shot/run_eval.py --model.model_path results/trainval/best_model.pt`.