https://github.com/idiap/importance-sampling

Code for experiments regarding importance sampling for training neural networks

Importance Sampling
====================

This Python package provides a library that accelerates the training of
arbitrary neural networks created with `Keras `__ using
**importance sampling**.

.. code:: python

    # Keras imports

    from importance_sampling.training import ImportanceTraining

    x_train, y_train, x_val, y_val = load_data()
    model = create_keras_model()
    model.compile(
        optimizer="adam",
        loss="categorical_crossentropy",
        metrics=["accuracy"]
    )

    ImportanceTraining(model).fit(
        x_train, y_train,
        batch_size=32,
        epochs=10,
        verbose=1,
        validation_data=(x_val, y_val)
    )

    model.evaluate(x_val, y_val)

Importance sampling for Deep Learning is an active research field and this
library is still under development, so your mileage may vary.
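
To give a rough sense of the idea behind the package, the following NumPy
sketch draws a mini-batch with probability proportional to a per-sample score
and reweights each draw by ``1 / (N * p_i)`` so that the weighted average stays
unbiased. It is an illustration only and does not use or reproduce the
library's code; the function name ``sample_batch`` and the use of the
per-sample loss as the score are assumptions made for this example.

```python
import numpy as np


def sample_batch(scores, batch_size, rng):
    """Draw indices with probability proportional to `scores`.

    Returns the indices and the weights 1 / (N * p_i) that keep the
    weighted batch mean an unbiased estimate of the full mean.
    (Hypothetical helper, not part of the importance_sampling package.)
    """
    scores = np.asarray(scores, dtype=float)
    probs = scores / scores.sum()
    idx = rng.choice(len(scores), size=batch_size, replace=True, p=probs)
    weights = 1.0 / (len(scores) * probs[idx])
    return idx, weights


rng = np.random.default_rng(0)
# Stand-in for real per-sample losses; the small offset keeps scores positive.
per_sample_loss = rng.exponential(size=10_000) + 1e-3

idx, w = sample_batch(per_sample_loss, 256, rng)
estimate = np.mean(w * per_sample_loss[idx])
print(estimate, per_sample_loss.mean())
```

Because the score here equals the quantity being averaged, the weighted
estimate matches the true mean exactly. With approximate scores the estimate
remains unbiased but has nonzero variance.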

Relevant Research
-----------------

**Ours**

* Not All Samples Are Created Equal: Deep Learning with Importance Sampling [`preprint `__]
* Biased Importance Sampling for Deep Neural Network Training [`preprint `__]

**By others**

* Stochastic optimization with importance sampling for regularized loss
  minimization [`pdf `__]
* Variance reduction in SGD by distributed importance sampling [`pdf `__]

Dependencies & Installation
---------------------------

If you already have a working Keras installation, you normally just need to
run ``pip install keras-importance-sampling``. The dependencies are:

* ``Keras`` > 2
* A Keras backend among *Tensorflow*, *Theano* and *CNTK*
* ``blinker``
* ``numpy``
* ``matplotlib``, ``seaborn``, ``scikit-learn`` are optional (used by the plot
  scripts)
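
As a quick sanity check before installing, a small script along these lines
can report which of the modules listed above are not importable in the current
environment. This is only a sketch: ``missing_modules`` is a hypothetical
helper, and the import names (for example ``sklearn`` for ``scikit-learn``) are
the usual ones rather than anything specified by this package.

```python
from importlib.util import find_spec

# Import names for the dependencies listed above (assumed, not package API).
REQUIRED = ["keras", "blinker", "numpy"]
OPTIONAL = ["matplotlib", "seaborn", "sklearn"]  # scikit-learn imports as sklearn


def missing_modules(names):
    """Return the entries of `names` that cannot be located for import."""
    return [name for name in names if find_spec(name) is None]


if __name__ == "__main__":
    print("missing required:", missing_modules(REQUIRED))
    print("missing optional:", missing_modules(OPTIONAL))
```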

Documentation
-------------

The module has a dedicated `documentation site
`__ but you can also read the
`source code `__ and the `examples
`__ to get an
idea of how the library should be used and extended.

Examples
---------

In the ``examples`` folder you can find some Keras examples that have been edited
to use importance sampling.

Code examples
*************

In this section we will showcase part of the API that can be used to train
neural networks with importance sampling.

.. code:: python

    # Import what is needed to build the Keras model
    from keras import backend as K
    from keras.layers import Dense, Activation, Flatten
    from keras.models import Sequential

    # Import a toy dataset and the importance training
    from importance_sampling.datasets import MNIST
    from importance_sampling.training import ImportanceTraining

    def create_nn():
        """Build a simple fully connected NN"""
        model = Sequential([
            Flatten(input_shape=(28, 28, 1)),
            Dense(40, activation="tanh"),
            Dense(40, activation="tanh"),
            Dense(10),
            Activation("softmax")  # Needs to be separate to automatically
                                   # get the preactivation outputs
        ])

        model.compile(
            optimizer="adam",
            loss="categorical_crossentropy",
            metrics=["accuracy"]
        )

        return model

    if __name__ == "__main__":
        # Load the data
        dataset = MNIST()
        x_train, y_train = dataset.train_data[:]
        x_test, y_test = dataset.test_data[:]

        # Create the NN and keep the initial weights
        model = create_nn()
        weights = model.get_weights()

        # Train with uniform sampling
        K.set_value(model.optimizer.lr, 0.01)
        model.fit(
            x_train, y_train,
            batch_size=64, epochs=10,
            validation_data=(x_test, y_test)
        )

        # Train with importance sampling
        model.set_weights(weights)
        K.set_value(model.optimizer.lr, 0.01)
        ImportanceTraining(model).fit(
            x_train, y_train,
            batch_size=64, epochs=2,
            validation_data=(x_test, y_test)
        )

Using the script
****************

The following terminal commands train a small VGG-like network to ~0.65% error
on MNIST (the numbers are from a CPU).

.. code::

    $ # Train a small cnn with mnist for 500 mini-batches using importance
    $ # sampling with bias to achieve ~ 0.65% error (on the CPU).
    $ time ./importance_sampling.py \
    >   small_cnn \
    >   oracle-gnorm \
    >   model \
    >   predicted \
    >   mnist \
    >   /tmp/is \
    >   --hyperparams 'batch_size=i128;lr=f0.003;lr_reductions=I10000' \
    >   --train_for 500 --validate_every 500
    real    1m41.985s
    user    8m14.400s
    sys     0m35.900s
    $
    $ # And with uniform sampling to achieve ~ 0.9% error.
    $ time ./importance_sampling.py \
    >   small_cnn \
    >   oracle-loss \
    >   uniform \
    >   unweighted \
    >   mnist \
    >   /tmp/uniform \
    >   --hyperparams 'batch_size=i128;lr=f0.003;lr_reductions=I10000' \
    >   --train_for 3000 --validate_every 3000
    real    9m23.971s
    user    47m32.600s
    sys     3m4.188s
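
The ``--hyperparams`` value in the commands above looks like a
``;``-separated list of ``key=<tag><value>`` pairs. The sketch below shows one
way such a string could be read in Python. It is not the script's own parser:
``parse_hyperparams`` is a hypothetical helper, and it assumes that the tag
``i`` marks an integer and ``f`` a float, which is inferred from the values
shown. Any other tag (such as ``I``) is returned unconverted, since its meaning
is not stated here.

```python
def parse_hyperparams(spec):
    """Parse 'key=<tag><value>;...' into a dict.

    Assumed tags: 'i' -> int, 'f' -> float. Any other tag is kept as a raw
    (tag, text) pair because its meaning is not documented in this README.
    """
    converters = {"i": int, "f": float}
    params = {}
    for item in spec.split(";"):
        if not item:
            continue  # tolerate a trailing ';'
        key, _, tagged = item.partition("=")
        tag, text = tagged[:1], tagged[1:]
        if tag in converters:
            params[key] = converters[tag](text)
        else:
            params[key] = (tag, text)
    return params


print(parse_hyperparams("batch_size=i128;lr=f0.003;lr_reductions=I10000"))
# {'batch_size': 128, 'lr': 0.003, 'lr_reductions': ('I', '10000')}
```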