# Detecting Errors and Estimating Accuracy on Unlabeled Data with Self-training Ensembles
This repository is the official implementation of [Detecting Errors and Estimating Accuracy on Unlabeled Data with Self-training Ensembles](https://arxiv.org/abs/2106.15728).

## Requirements
* The code is tested under Ubuntu Linux 16.04.1 with Python 3.6, and requires the following packages: [PyTorch](https://pytorch.org/), [numpy](http://www.numpy.org/) and [scikit-learn](https://scikit-learn.org/).
* To install requirements: `pip install -r requirements.txt`

## Downloading Datasets
* [MNIST-M](http://bit.ly/2fNqL6N): download it from Google Drive, extract the files, and place them in `./dataset/mnist_m/`.
* [SVHN](http://ufldl.stanford.edu/housenumbers/): download the Format 2 data (`*.mat`) and place the files in `./dataset/svhn/`.
* [USPS](https://www.kaggle.com/bistaumanga/usps-dataset): download the `usps.h5` file and place it in `./dataset/usps/`.
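
After downloading, the `dataset` directory should look roughly like this (a sketch based on the paths above; the SVHN file names shown are the standard Format 2 downloads):

```
dataset/
├── mnist_m/          # extracted MNIST-M files
├── svhn/             # Format 2 files, e.g. train_32x32.mat and test_32x32.mat
└── usps/
    └── usps.h5
```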

## Overview of the Code
* `train_model.py`: train standard models via supervised learning.
* `train_dann.py`: train domain adaptive (DANN) models.
* `eval_pipeline.py`: evaluate various methods on all tasks.

## Running Experiments

### Examples

* To train a standard model via supervised learning, you can use the following command:

`python train_model.py --source-dataset {source dataset} --model-type {model type} --base-dir {directory to save the model}`

`{source dataset}` can be `mnist`, `mnist-m`, `svhn` or `usps`.

`{model type}` can be `typical_dnn` or `dann_arch`.
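
For example, the following trains a typical DNN on MNIST (the `checkpoints/` directory name is an arbitrary choice):

```
python train_model.py --source-dataset mnist --model-type typical_dnn --base-dir checkpoints/
```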

* To train a domain adaptive (DANN) model, you can use the following command:

`python train_dann.py --source-dataset {source dataset} --target-dataset {target dataset} --base-dir {directory to save the model} [--test-time]`

`{source dataset}` (or `{target dataset}`) can be `mnist`, `mnist-m`, `svhn` or `usps`.

The optional `--test-time` flag replaces the target training dataset with the target test dataset.
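
For example, the following trains a DANN model adapting from MNIST to USPS (again using a hypothetical `checkpoints/` directory):

```
python train_dann.py --source-dataset mnist --target-dataset usps --base-dir checkpoints/
```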

* To evaluate a method on all training-test dataset pairs, you can use the following command:

`python eval_pipeline.py --model-type {model type} --method {method}`

`{model type}` can be `typical_dnn` or `dann_arch`.

`{method}` can be `conf_avg`, `ensemble_conf_avg`, `conf`, `trust_score`, `proxy_risk`, `our_ri` or `our_rm`.
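
For example, the following evaluates the `our_ri` method with the standard DNN architecture on all training-test dataset pairs:

```
python eval_pipeline.py --model-type typical_dnn --method our_ri
```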

### Training

You can run the following scripts to pre-train all the models needed for the experiments, as shown in the example after the list.
* `run_all_model_training.sh`: train all supervised learning models.
* `run_all_dann_training.sh`: train all DANN models.
* `run_all_ensemble_training.sh`: train all ensemble models.
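
For example, assuming the scripts are at the repository root:

```
bash run_all_model_training.sh
bash run_all_dann_training.sh
bash run_all_ensemble_training.sh
```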

### Evaluation

You can run the following script to reproduce the results reported in the paper.
* `run_all_evaluation.sh`: evaluate all methods on all tasks.
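
For example:

```
bash run_all_evaluation.sh
```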

### Pre-trained Models

We provide pre-trained models produced by our training scripts: `run_all_model_training.sh`, `run_all_dann_training.sh` and `run_all_ensemble_training.sh`.

You can download the pre-trained models from [Google Drive](https://drive.google.com/drive/folders/1PCUVBW1Wf1JqyN_goC1GiAi-sYFtO26e?usp=sharing).

## Experimental Results
![Main Results](results.png)

## Acknowledgements
Part of this code is inspired by [estimating-generalization](https://github.com/chingyaoc/estimating-generalization) and [TrustScore](https://github.com/google/TrustScore).

## Citation
Please cite our work if you use this codebase:
```
@article{chen2021detecting,
  title={Detecting Errors and Estimating Accuracy on Unlabeled Data with Self-training Ensembles},
  author={Chen, Jiefeng and Liu, Frederick and Avci, Besim and Wu, Xi and Liang, Yingyu and Jha, Somesh},
  journal={Advances in Neural Information Processing Systems},
  volume={34},
  year={2021}
}
```

## License
Please refer to the [LICENSE](LICENSE) file.