Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/kristpapadopoulos/keras-stochastic-weight-averaging
Keras callback function for stochastic weight averaging
- Host: GitHub
- URL: https://github.com/kristpapadopoulos/keras-stochastic-weight-averaging
- Owner: kristpapadopoulos
- Created: 2018-07-04T05:47:56.000Z (over 6 years ago)
- Default Branch: master
- Last Pushed: 2022-06-11T14:19:17.000Z (over 2 years ago)
- Last Synced: 2024-08-01T15:30:14.652Z (6 months ago)
- Topics: callback, callback-functions, deep-learning, ensemble, keras, keras-callback, keras-implementations, stochastic-weight-averaging, stochasticweightaveraging, weightaveraging, weights
- Language: Python
- Homepage:
- Size: 7.81 KB
- Stars: 55
- Watchers: 3
- Forks: 22
- Open Issues: 0
Metadata Files:
- Readme: README.md
README
### Stochastic Weight Averaging with a Keras callback function
Stochastic weight averaging, following the paper [Averaging Weights Leads to Wider Optima and Better Generalization](https://arxiv.org/abs/1803.05407). The file `swa.py` contains an implementation of stochastic weight averaging (SWA) with a constant learning rate for a user-defined number of epochs.
The callback is instantiated with the epoch at which averaging starts, which determines how many of the final epochs are averaged, and a file path for saving the model's final weights after SWA.
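The averaging step itself can be sketched without Keras. The class below is a hypothetical, framework-free illustration, not the code in `swa.py`: it assumes weights arrive once per epoch as a list of tensors (plain floats or NumPy arrays, like the list returned by Keras `model.get_weights()`) and keeps an equal-weight running mean of every snapshot from `swa_start` onward. The names `RunningWeightAverage` and `swa_start` are illustrative only.

```python
from typing import List, Optional, Sequence


class RunningWeightAverage:
    """Hypothetical sketch of the SWA averaging step (not the repo's swa.py).

    Call on_epoch_end(epoch, weights) once per epoch, where `weights` is a
    list of tensors such as the output of Keras `model.get_weights()`.
    """

    def __init__(self, swa_start: int) -> None:
        self.swa_start = swa_start  # first 0-based epoch included in the average
        self.n_averaged = 0         # number of snapshots folded into the average
        self.swa_weights: Optional[List] = None

    def on_epoch_end(self, epoch: int, weights: Sequence) -> None:
        if epoch < self.swa_start:
            return  # before the averaging window: these weights are ignored
        if self.swa_weights is None:
            # first snapshot in the window becomes the initial average
            self.swa_weights = list(weights)
        else:
            # incremental equal-weight mean: avg_new = (avg * n + w) / (n + 1)
            n = self.n_averaged
            self.swa_weights = [
                (avg * n + w) / (n + 1)
                for avg, w in zip(self.swa_weights, weights)
            ]
        self.n_averaged += 1


# Usage: epochs 2, 3 and 4 are averaged, epochs 0 and 1 are skipped.
averager = RunningWeightAverage(swa_start=2)
for epoch, w in enumerate([[0.0], [10.0], [1.0], [2.0], [3.0]]):
    averager.on_epoch_end(epoch, w)
print(averager.swa_weights)  # [2.0]
```

Because the mean is updated incrementally, only one extra copy of the weights is held in memory, rather than one copy per averaged epoch.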
### Example
The total number of training epochs is 150, and SWA starts at epoch 140, so the last 10 epochs are averaged.
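Spelled out, assuming 0-based epoch indices as passed to Keras callback hooks (so the 150 epochs are numbered 0 to 149), the averaged weights are the equal-weight mean of the last $n$ epochs, where $w_i$ denotes the weights at the end of epoch $i$:

$$
n = 150 - 140 = 10, \qquad
\bar{w}_{\mathrm{SWA}} = \frac{1}{n} \sum_{i=140}^{149} w_i .
$$

The same result can be reached one epoch at a time with the running update

$$
\bar{w} \leftarrow \frac{k\,\bar{w} + w_i}{k + 1}, \qquad k = i - 140, \quad i = 140, \dots, 149,
$$

which sets $\bar{w} = w_{140}$ at $k = 0$ and equals $\bar{w}_{\mathrm{SWA}}$ after the update at $i = 149$ ($k = 9$).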
```python
from swa import SWA

# specify the number of training epochs
number_of_epochs = 150

# specify the start epoch of stochastic weight averaging
swa = SWA(140, filepath=None)

# call SWA during model fitting
model.fit(..., epochs=number_of_epochs, callbacks=[swa])
```
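The README states that averaging runs with a constant learning rate but does not show how that rate is set. If you manage the rate yourself, one option is a schedule that decays to a lower rate before averaging begins and then holds it constant, which is the general shape the SWA paper pairs with averaging. The sketch below is an assumption, not part of `swa.py`: the linear decay and the names `make_swa_schedule`, `lr_init`, `swa_lr` and `swa_start` are hypothetical choices for illustration.

```python
from typing import Callable


def make_swa_schedule(lr_init: float, swa_lr: float,
                      swa_start: int) -> Callable[..., float]:
    """Hypothetical schedule: linear decay from lr_init, then constant swa_lr.

    The returned function has the (epoch, lr) signature accepted by Keras
    `LearningRateScheduler`; the current `lr` argument is ignored.
    """

    def schedule(epoch: int, lr: float = 0.0) -> float:
        if epoch >= swa_start:
            return swa_lr  # constant learning rate while weights are averaged
        # linear interpolation: lr_init at epoch 0, approaching swa_lr at swa_start
        frac = epoch / swa_start
        return lr_init + (swa_lr - lr_init) * frac

    return schedule


schedule = make_swa_schedule(lr_init=0.1, swa_lr=0.01, swa_start=140)
print(schedule(0), schedule(140), schedule(149))  # 0.1 0.01 0.01
```

In a Keras training run, the function would be wrapped as `keras.callbacks.LearningRateScheduler(schedule)` and passed in the same `callbacks` list as the SWA callback, with `swa_start` matching the epoch given to `SWA`.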