# ADA: (Yet) Another Domain Adaptation library
[![Documentation Status](https://readthedocs.org/projects/pytorch-ada/badge/)](https://pytorch-ada.readthedocs.io/)
![Lint](https://github.com/criteo-research/pytorch-ada/workflows/lint/badge.svg)

## Context

The aim of ADA is to help researchers build new methods for unsupervised and semi-supervised domain adaptation. The library is built on top of [PyTorch-Lightning](https://pytorch-lightning.readthedocs.io/en/latest/new-project.html), enabling fast development of new models.

We built ADA with the idea of:

- minimizing the boilerplate when developing a new method (loading data from several domains, logging errors, switching from CPU to GPU).
- allowing fair comparison between methods by running all of them within the exact same environment.

You can find an introduction to ADA on [Medium](https://medium.com/criteo-labs/introducing-ada-another-domain-adaptation-library-5df8b79378ee), and more complete documentation at [pytorch-ada.readthedocs.io](https://pytorch-ada.readthedocs.io/).

## Quick description

Three types of methods are available for unsupervised domain adaptation:

- Adversarial methods: Domain-adversarial neural networks ([DANN](https://arxiv.org/abs/1505.07818)) and Conditional Adversarial Domain Adaptation networks ([CDAN](https://papers.nips.cc/paper/7436-conditional-adversarial-domain-adaptation.pdf)),
- Optimal-Transport-based methods: Wasserstein distance guided representation learning ([WDGRL](https://arxiv.org/pdf/1707.01217.pdf)), for which we propose two implementations, the second being a variant better adapted to PyTorch-Lightning and allowing for multi-GPU training,
- MMD-based methods: Deep Adaptation Networks ([DAN](http://proceedings.mlr.press/v37/long15.pdf)) and Joint Adaptation Networks ([JAN](https://arxiv.org/pdf/1605.06636.pdf)).

All these methods are implemented in [`models/architectures.py`](adalib/ada/models/architectures.py).
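
As a concrete illustration of how the adversarial methods (DANN, CDAN) couple the networks, the sketch below shows the standard gradient reversal layer: an identity in the forward pass that flips (and scales) gradients flowing back from the domain classifier into the feature extractor. This is the textbook trick in minimal PyTorch, not a copy of the library's own implementation.

```python
import torch
from torch.autograd import Function

class GradReverse(Function):
    """Identity in the forward pass; multiplies gradients by -lambda in the backward pass."""

    @staticmethod
    def forward(ctx, x, lambda_):
        ctx.lambda_ = lambda_
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Reversing the gradient trains the feature extractor to *confuse*
        # the domain classifier, which is the core of adversarial adaptation.
        return -ctx.lambda_ * grad_output, None

def grad_reverse(x, lambda_=1.0):
    return GradReverse.apply(x, lambda_)
```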

![The 3-block architecture for domain adaptation](docs/images/ada_blocks.png)

Adversarial and OT-based methods both rely on 3 networks:

- a feature extractor network mapping inputs $x\in\mathcal{X}$ to a latent space $\mathcal{Z}$,
- a task classifier network that learns to predict labels $y \in \mathcal{Y}$ from latent vectors,
- a domain classifier network that tries to predict whether samples in $\mathcal{Z}$ come from the source or target domain.

MMD-based methods don't use this critic (domain classifier) network; instead, they align the source and target feature distributions directly with an MMD penalty.
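
As a rough sketch of this three-block layout (layer sizes and module names below are illustrative, not the library's actual classes), the three networks and the two losses fit together roughly as follows:

```python
import torch
import torch.nn as nn

# Hypothetical three-block layout on 2-D toy inputs (not ADA's actual classes).
feature_extractor = nn.Sequential(            # x in X  ->  z in Z
    nn.Linear(2, 64), nn.ReLU(),
    nn.Linear(64, 32), nn.ReLU(),
)
task_classifier = nn.Linear(32, 2)            # z -> class logits y
domain_classifier = nn.Sequential(            # z -> source/target logit (the "critic")
    nn.Linear(32, 16), nn.ReLU(),
    nn.Linear(16, 1),
)

x_src, y_src = torch.randn(8, 2), torch.randint(0, 2, (8,))
x_tgt = torch.randn(8, 2)                     # unlabeled target batch

z_src, z_tgt = feature_extractor(x_src), feature_extractor(x_tgt)
task_loss = nn.functional.cross_entropy(task_classifier(z_src), y_src)

# Adversarial / OT-based methods add a domain loss computed by the critic;
# MMD-based methods drop the critic and penalize a kernel distance between z_src and z_tgt.
dom_logits = torch.cat([domain_classifier(z_src), domain_classifier(z_tgt)])
dom_labels = torch.cat([torch.zeros(8, 1), torch.ones(8, 1)])
domain_loss = nn.functional.binary_cross_entropy_with_logits(dom_logits, dom_labels)

total_loss = task_loss + domain_loss          # trade-off weight omitted for brevity
```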

## Quick start

First you need to install the library. It has been tested with Python 3.6+ and recent versions of PyTorch-Lightning.

If you want to create a new conda environment, run:

```
conda create -n adaenv python=3.7
conda activate adaenv
```

Install the library (in developer mode if you want to develop your own models later on; otherwise you can skip the `-e` flag):
```
pip install -e adalib
```

_Note_: on **Windows**, it may be necessary to install pytorch and torchvision with conda first:
```
conda install -c pytorch pytorch
conda install -c pytorch torchvision
pip install -e adalib
```

Run one of the scripts:
```
cd scripts
python run_simple.py
```

By default, this script launches experiments with all of the available methods on a toy blobs dataset. It doesn't take any command-line parameters; you can change the settings easily from within the script.
It may take a few minutes to finish.

Most parameters can be changed through configuration files, which are all grouped in the `configs` folder:
- datasets
- network layers and training parameters
- methods (Source, DANN, CDAN...), and their specific parameters

## Advanced options

The script `run_full_options.py` runs the same kind of experiments but allows for more variants (semi-supervised learning, unbalanced classes, GPU training and MLFlow logging). You can run it without parameters, or with `-h` to get help.
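
For example, from the `scripts` folder:

```
python run_full_options.py -h
```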

### MLFlow

You can log results to MLFlow.
Start an MLFlow server in another terminal (the tracking UI will then be available at http://localhost:31014):
```
conda activate adaenv
mlflow ui --port=31014
```

### Streamlit application

Optionally, you can use the `streamlit` app. First install `streamlit` with `pip install streamlit`, then launch the app like this:

```
streamlit run run_toys_app.py
```
This will start a web app on the default port (8501), which you can view in your browser. It looks like this:
![Streamlit app screenshot](docs/images/streamlit_screenshot.png)

## Benchmark results

### MNIST -> MNIST-M (5 runs)

|Method|source acc|target acc|
|:----|:---:|:---:|
|Source|89.0% +- 2.52|34.0% +- 1.71|
|DANN|94.2% +- 1.57|37.5% +- 2.85|
|CDAN|98.7% +- 0.19|68.4% +- 1.80|
|CDAN-E|98.7% +- 0.12|69.6% +- 1.51|
|DAN|98.0% +- 0.68|47.0% +- 1.85|
|JAN|96.4% +- 4.57|52.9% +- 2.16|
|WDGRL|93.9% +- 2.70|52.0% +- 4.82|

### MNIST -> USPS (5 runs)

|Method|source acc|target acc|
|:----|:---:|:---:|
|Source|99.2% +- 0.08|94.2% +- 1.07|
|DANN|99.1% +- 0.15|93.8% +- 1.06|
|CDAN|98.8% +- 0.17|90.7% +- 1.17|
|CDAN-E|98.9% +- 0.11|90.3% +- 0.98|
|DAN|99.0% +- 0.14|95.0% +- 0.83|
|JAN|98.6% +- 0.30|89.5% +- 2.00|
|WDGRL|98.7% +- 0.13|85.7% +- 6.57|

Check out the [documentation benchmark page](https://pytorch-ada.readthedocs.io/en/latest/benchmarks.html) for more results.

## Contributing

### Code

You can find the latest version on GitHub. Before submitting code, please run `black` to keep the code formatting clean:

```
pip install black
black .
```

### Documentation

First install `sphinx`, `sphinx-paramlinks` and `recommonmark` with `pip`.
Then generate the documentation:

```
cd docs
sphinx-apidoc -o source/ ../adalib/ada ../scripts/
make html
```

## Citing

If this library is useful for your research, please cite:

```
@misc{adalib2020,
  title={(Yet) Another Domain Adaptation library},
  author={Tousch, Anne-Marie and Renaudin, Christophe},
  url={https://github.com/criteo-research/pytorch-ada},
  year={2020}
}
```