Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/jfc43/self-training-ensembles
Proposes a principled and practically effective framework for unsupervised accuracy estimation and error detection, with theoretical analysis and state-of-the-art performance.
- Host: GitHub
- URL: https://github.com/jfc43/self-training-ensembles
- Owner: jfc43
- License: apache-2.0
- Created: 2021-10-14T19:07:03.000Z (almost 3 years ago)
- Default Branch: master
- Last Pushed: 2022-02-17T02:32:04.000Z (over 2 years ago)
- Last Synced: 2024-07-10T08:43:32.867Z (2 months ago)
- Topics: deep-learning, error-detection, machine-learning, pytorch, self-training-ensembles, unsupervised-accuracy-estimation
- Language: Python
- Homepage:
- Size: 237 KB
- Stars: 14
- Watchers: 1
- Forks: 1
- Open Issues: 1
Metadata Files:
- Readme: README.md
- License: LICENSE
# Detecting Errors and Estimating Accuracy on Unlabeled Data with Self-training Ensembles
This repository is the official implementation of [Detecting Errors and Estimating Accuracy on Unlabeled Data with Self-training Ensembles](https://arxiv.org/abs/2106.15728).

## Requirements
* Tested on Ubuntu Linux 16.04.1 with Python 3.6; requires [PyTorch](https://pytorch.org/), [numpy](http://www.numpy.org/), and [scikit-learn](https://scikit-learn.org/).
* To install requirements: `pip install -r requirements.txt`

## Downloading Datasets
* [MNIST-M](http://bit.ly/2fNqL6N): download it from Google Drive, extract the files, and place them in `./dataset/mnist_m/`.
* [SVHN](http://ufldl.stanford.edu/housenumbers/): download the Format 2 data (`*.mat` files) and place them in `./dataset/svhn/`.
* [USPS](https://www.kaggle.com/bistaumanga/usps-dataset): download the `usps.h5` file and place it in `./dataset/usps/`.

## Overview of the Code
* `train_model.py`: train standard models via supervised learning.
* `train_dann.py`: train domain adaptive (DANN) models.
* `eval_pipeline.py`: evaluate various methods on all tasks.

## Running Experiments
### Examples
* To train a standard model via supervised learning, you can use the following command:
`python train_model.py --source-dataset {source dataset} --model-type {model type} --base-dir {directory to save the model}`
`{source dataset}` can be `mnist`, `mnist-m`, `svhn` or `usps`.
`{model type}` can be `typical_dnn` or `dann_arch`.
* To train a domain adaptive (DANN) model, you can use the following command:
`python train_dann.py --source-dataset {source dataset} --target-dataset {target dataset} --base-dir {directory to save the model} [--test-time]`
`{source dataset}` (or `{target dataset}`) can be `mnist`, `mnist-m`, `svhn` or `usps`.
The optional `--test-time` flag replaces the target training dataset with the target test dataset.
* To evaluate a method on all training-test dataset pairs, you can use the following command:
`python eval_pipeline.py --model-type {model type} --method {method}`
`{model type}` can be `typical_dnn` or `dann_arch`.
`{method}` can be `conf_avg`, `ensemble_conf_avg`, `conf`, `trust_score`, `proxy_risk`, `our_ri` or `our_rm`.
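To make the evaluation concrete, here is a toy sketch of the general idea behind ensemble-based unsupervised accuracy estimation: an ensemble trained on labeled source data scores a deployed model's predictions on unlabeled target data, with the agreement rate serving as the accuracy estimate and disagreements flagged as suspected errors. This is only an illustration of the intuition using generic scikit-learn models on synthetic data; the paper's actual method additionally self-trains the ensemble on pseudo-labeled target data and comes with theoretical guarantees.

```python
# Toy sketch: estimate a classifier's accuracy on unlabeled data from
# ensemble agreement. Illustration only -- not the paper's exact algorithm,
# which self-trains the ensemble on pseudo-labeled target data.
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=2000, n_features=20, random_state=0)
X_src, X_tgt, y_src, y_tgt = train_test_split(X, y, test_size=0.5, random_state=0)

# "Deployed" model whose accuracy on the unlabeled target set we estimate.
model = LogisticRegression(max_iter=1000).fit(X_src, y_src)

# Ensemble trained on the same labeled source data.
ensemble = RandomForestClassifier(n_estimators=50, random_state=0).fit(X_src, y_src)

model_pred = model.predict(X_tgt)
ens_pred = ensemble.predict(X_tgt)

# Agreement rate = unsupervised accuracy estimate;
# disagreements = suspected prediction errors.
est_acc = float(np.mean(model_pred == ens_pred))
true_acc = float(np.mean(model_pred == y_tgt))  # uses labels, for reference only
suspected_errors = np.flatnonzero(model_pred != ens_pred)

print(f"estimated accuracy: {est_acc:.3f}")
print(f"true accuracy:      {true_acc:.3f}")
print(f"flagged {suspected_errors.size} suspected errors")
```

In this sketch the estimate is good only insofar as the ensemble's mistakes correlate with the model's mistakes; the self-training step in the paper is what makes the ensemble informative about the target distribution.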
### Training
You can run the following scripts to pre-train all models needed for the experiments.
* `run_all_model_training.sh`: train all supervised learning models.
* `run_all_dann_training.sh`: train all DANN models.
* `run_all_ensemble_training.sh`: train all ensemble models.

### Evaluation
You can run the following script to get the results reported in the paper.
* `run_all_evaluation.sh`: evaluate all methods on all tasks.

### Pre-trained Models
We provide pre-trained models produced by our training scripts: `run_all_model_training.sh`, `run_all_dann_training.sh` and `run_all_ensemble_training.sh`.
You can download the pre-trained models from [Google Drive](https://drive.google.com/drive/folders/1PCUVBW1Wf1JqyN_goC1GiAi-sYFtO26e?usp=sharing).
## Experimental Results
![Main Results](results.png)

## Acknowledgements
Part of this code is inspired by [estimating-generalization](https://github.com/chingyaoc/estimating-generalization) and [TrustScore](https://github.com/google/TrustScore).

## Citation
Please cite our work if you use the codebase:
```
@article{chen2021detecting,
title={Detecting Errors and Estimating Accuracy on Unlabeled Data with Self-training Ensembles},
author={Chen, Jiefeng and Liu, Frederick and Avci, Besim and Wu, Xi and Liang, Yingyu and Jha, Somesh},
journal={Advances in Neural Information Processing Systems},
volume={34},
year={2021}
}
```

## License
Please refer to the [LICENSE](LICENSE).