https://github.com/markdtw/temperature-scaling-tensorflow
On Calibration of Modern Neural Networks - tensorflow implementation
- Host: GitHub
- URL: https://github.com/markdtw/temperature-scaling-tensorflow
- Owner: markdtw
- Created: 2018-06-16T09:57:56.000Z (almost 8 years ago)
- Default Branch: master
- Last Pushed: 2018-06-16T11:19:18.000Z (almost 8 years ago)
- Last Synced: 2025-04-09T02:06:11.405Z (about 1 year ago)
- Language: Python
- Size: 5.86 KB
- Stars: 30
- Watchers: 4
- Forks: 12
- Open Issues: 1
# Temperature Scaling tensorflow
TensorFlow implementation of [On Calibration of Modern Neural Networks](https://arxiv.org/abs/1706.04599).
What this repo can do:
- Train ResNet_v1_110
- Calibrate its output on CIFAR-10/100
- Calibrate any of your TensorFlow networks with the ```temp_scaling``` function
What this repo *cannot* do:
- Calculate ECE (Expected Calibration Error)
Official PyTorch implementation by @gpleiss [here](https://github.com/gpleiss/temperature_scaling).
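Since ECE is not computed here, the following is a minimal pure-Python sketch of the metric as the paper defines it: the weighted average, over equal-width confidence bins, of the gap between accuracy and mean confidence. The function name, its arguments, and the default bin count are assumptions for illustration and are not part of this repo.

```python
def expected_calibration_error(confidences, predictions, labels, n_bins=15):
    """Weighted average of |accuracy - confidence| over equal-width bins.

    confidences: max softmax probability per sample, in [0, 1]
    predictions: predicted class index per sample
    labels:      true class index per sample
    """
    n = len(labels)
    conf_sum = [0.0] * n_bins
    hits = [0] * n_bins
    count = [0] * n_bins

    for conf, pred, label in zip(confidences, predictions, labels):
        # Left-closed bins [k/M, (k+1)/M); the paper uses right-closed
        # intervals, which only differs for values exactly on a bin edge.
        idx = min(int(conf * n_bins), n_bins - 1)
        conf_sum[idx] += conf
        hits[idx] += int(pred == label)
        count[idx] += 1

    ece = 0.0
    for b in range(n_bins):
        if count[b] == 0:
            continue  # empty bins contribute nothing
        accuracy = hits[b] / count[b]
        avg_conf = conf_sum[b] / count[b]
        ece += (count[b] / n) * abs(accuracy - avg_conf)
    return ece


if __name__ == "__main__":
    # Toy example: four predictions, one of them wrong.
    print(expected_calibration_error([0.95, 0.95, 0.55, 0.55],
                                     [1, 1, 0, 0],
                                     [1, 1, 0, 1],
                                     n_bins=10))
```

Comparing this value before and after dividing the logits by the fitted temperature is one way to check whether calibration helped.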
## Prerequisites
- Python 3.5
- [NumPy](http://www.numpy.org/)
- [TensorFlow 1.8](https://www.tensorflow.org/)
## Data
- [CIFAR-10/100](https://www.cs.toronto.edu/~kriz/cifar.html)
## Preparation
- Create a `data/` folder, then download and extract the Python version of the dataset from the CIFAR webpage.
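To sanity-check the extracted files, a batch can be unpickled directly. This is a sketch based on the loading snippet on the CIFAR webpage. The helper name `load_cifar_batch` is hypothetical, and the example path assumes the archive was extracted inside `data/`, which may not match the layout `main.py` expects.

```python
import pickle


def load_cifar_batch(path):
    """Load one pickled CIFAR batch file and return (image rows, labels)."""
    with open(path, "rb") as f:
        # The batches were pickled under Python 2, so keys come back as bytes.
        batch = pickle.load(f, encoding="bytes")
    # CIFAR-10 stores labels under b"labels", CIFAR-100 under b"fine_labels".
    if b"labels" in batch:
        labels = batch[b"labels"]
    else:
        labels = batch[b"fine_labels"]
    # For the real files, b"data" is a NumPy uint8 array with 3072 values
    # per row, so NumPy must be installed for unpickling to succeed.
    return batch[b"data"], labels


# Example (assumed path, adjust to where the archive was extracted):
# data, labels = load_cifar_batch("data/cifar-10-batches-py/data_batch_1")
```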
## Train
First, train the model (ResNet 110 in this case) using default parameters:
```bash
python main.py
```
List the tunable hyperparameters:
```bash
python main.py --help
```
## Temperature Scaling
Then run temperature scaling on the validation set to calibrate your model:
```bash
python temp_scaling.py
```
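Conceptually, temperature scaling fits a single scalar T by minimizing the negative log-likelihood of the validation labels under softmax(logits / T), with the network weights kept fixed. The sketch below illustrates that idea in pure Python with a golden-section search. It is not the code in `temp_scaling.py`, which may use a different optimizer, and the names `nll` and `fit_temperature` and the search bounds are assumptions.

```python
import math


def nll(logits, labels, temperature):
    """Mean negative log-likelihood of labels under softmax(logits / T)."""
    total = 0.0
    for row, label in zip(logits, labels):
        scaled = [z / temperature for z in row]
        m = max(scaled)  # subtract the max for a stable log-sum-exp
        log_norm = m + math.log(sum(math.exp(s - m) for s in scaled))
        total += log_norm - scaled[label]
    return total / len(labels)


def fit_temperature(logits, labels, lo=0.05, hi=20.0, iters=80):
    """Golden-section search for the T in [lo, hi] that minimizes the NLL."""
    ratio = (math.sqrt(5) - 1) / 2
    a, b = lo, hi
    for _ in range(iters):
        c = b - ratio * (b - a)
        d = a + ratio * (b - a)
        if nll(logits, labels, c) < nll(logits, labels, d):
            b = d  # the minimum lies in [a, d]
        else:
            a = c  # the minimum lies in [c, b]
    return (a + b) / 2


if __name__ == "__main__":
    # Overconfident toy model: always predicts class 0 with a large margin,
    # but is right only 3 times out of 4, so the fitted T should exceed 1.
    val_logits = [[4.0, 0.0]] * 4
    val_labels = [0, 0, 0, 1]
    print(fit_temperature(val_logits, val_labels))
```

A one-dimensional search is enough here because only one parameter is learned; dividing by a positive T does not change the predicted class.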
Use the ```temp_var``` returned by the ```temp_scaling``` function with your model's logits to get calibrated output.
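Applying the temperature amounts to dividing the logits by it before the softmax. A minimal pure-Python sketch follows, with the hypothetical helper `calibrated_probs` and a plain float standing in for the evaluated value of the temperature variable. In TensorFlow the equivalent would presumably be dividing the logits tensor by the temperature before `tf.nn.softmax`.

```python
import math


def calibrated_probs(logits, temperature):
    """Softmax of logits / temperature for a single example."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]


if __name__ == "__main__":
    logits = [2.0, 1.0, 0.0]
    print(calibrated_probs(logits, 1.0))  # uncalibrated softmax
    print(calibrated_probs(logits, 2.0))  # softer, same argmax
```

A temperature above 1 lowers the confidence of the top class without changing which class is predicted, so accuracy is unaffected.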
## Notes
- ResNet_v1_110 is trained for 250 epochs, with the other hyperparameters set to the defaults from the original ResNet paper.
- The identity shortcut in ResNet_v1_110 is replaced with a projection shortcut, which adds two convolutional layers.
- Validation accuracy and test accuracy on CIFAR-100 are around 70%.
- Issues are welcome!
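For the shortcut note above, the two forms in the ResNet paper's notation are shown below. Here $\mathcal{F}$ is the residual branch and $W_s$ is a learned projection, typically a 1x1 convolution. That the projection is applied only at the two points where the feature-map dimensions change, which would account for the two extra layers, is an assumption and is not confirmed by the code.

```latex
% Identity shortcut
y = \mathcal{F}(x, \{W_i\}) + x

% Projection shortcut
y = \mathcal{F}(x, \{W_i\}) + W_s x
```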
## Resources
- [The paper](https://arxiv.org/abs/1706.04599).
- [Official PyTorch Implementation](https://github.com/gpleiss/temperature_scaling)