https://github.com/lorenzwalthert/kerasmisc
Add-on functionality for the R implementation of Keras
- Host: GitHub
- URL: https://github.com/lorenzwalthert/kerasmisc
- Owner: lorenzwalthert
- License: gpl-3.0
- Created: 2018-08-04T14:57:13.000Z (over 6 years ago)
- Default Branch: master
- Last Pushed: 2021-05-02T22:12:27.000Z (over 3 years ago)
- Last Synced: 2024-11-09T21:44:09.320Z (2 months ago)
- Topics: callback, keras, learning-rate
- Language: R
- Homepage:
- Size: 517 KB
- Stars: 3
- Watchers: 5
- Forks: 1
- Open Issues: 0
Metadata Files:
- Readme: README.Rmd
- License: LICENSE.md
README
---
output: github_document
---

[![R-CMD-check](https://github.com/lorenzwalthert/KerasMisc/workflows/R-CMD-check/badge.svg)](https://github.com/lorenzwalthert/KerasMisc/actions)

```{r setup, include = FALSE}
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>",
  fig.path = "man/figures/README-",
  out.width = "100%"
)
```

[![lifecycle](https://img.shields.io/badge/lifecycle-experimental-orange.svg)](https://www.tidyverse.org/lifecycle/#experimental)
[![Travis build status](https://travis-ci.org/lorenzwalthert/KerasMisc.svg?branch=master)](https://travis-ci.org/lorenzwalthert/KerasMisc)
[![Coverage status](https://codecov.io/gh/lorenzwalthert/KerasMisc/branch/master/graph/badge.svg)](https://codecov.io/github/lorenzwalthert/KerasMisc?branch=master)

# KerasMisc

The goal of KerasMisc is to provide a collection of tools that enhance the R
implementation of Keras. Currently, the package features:

* a Keras callback for cyclical learning rate scheduling as proposed by
  [Smith (2017)](https://arxiv.org/abs/1506.01186), closely adapted from the
  [Python implementation](https://github.com/bckenstler/CLR) and then extended
  so that the bands are scaled by a constant factor (typically < 1) after the
  validation loss has not improved for a while. For details, see the
  [README](https://github.com/bckenstler/CLR) of the Python
  implementation and the example below for dynamically adjusting bandwidths.

Contributions welcome.

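As a rough illustration of the triangular policy from Smith (2017) that the callback builds on, the learning rate at a given iteration can be computed in base R as sketched below. This is not KerasMisc code: `triangular_lr` is a hypothetical helper name, and the band scaling extension is deliberately left out.

```r
# Triangular cyclical learning rate (Smith, 2017), base R only.
# `triangular_lr` is a hypothetical helper for illustration, not part of
# KerasMisc.
triangular_lr <- function(iteration, step_size, base_lr, max_lr) {
  # Index of the current cycle; one full cycle spans 2 * step_size iterations.
  cycle <- floor(1 + iteration / (2 * step_size))
  # Relative distance from the peak of the current cycle, in [0, 1].
  x <- abs(iteration / step_size - 2 * cycle + 1)
  # Interpolate linearly between the lower and the upper band.
  base_lr + (max_lr - base_lr) * pmax(0, 1 - x)
}

# Two full cycles with a half-cycle length of 10 iterations.
lrs <- triangular_lr(0:40, step_size = 10, base_lr = 0.001, max_lr = 0.006)
range(lrs)
```

The rate starts at `base_lr`, climbs linearly to `max_lr` after `step_size` iterations, and falls back to `base_lr` after another `step_size` iterations.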
## Installation
You can install the development version of KerasMisc from GitHub with:

```{r, eval = FALSE}
remotes::install_github("lorenzwalthert/KerasMisc")
```

## Features

**Keras callbacks**
Let's create a model:

```{r}
library(keras)
library(KerasMisc)
dataset <- dataset_boston_housing()
c(c(train_data, train_targets), c(test_data, test_targets)) %<-% dataset

mean <- apply(train_data, 2, mean)
std <- apply(train_data, 2, sd)
train_data <- scale(train_data, center = mean, scale = std)
test_data <- scale(test_data, center = mean, scale = std)

model <- keras_model_sequential() %>%
  layer_dense(
    units = 64, activation = "relu",
    input_shape = dim(train_data)[[2]]
  ) %>%
  layer_dense(units = 64, activation = "relu") %>%
  layer_dense(units = 1)
model %>% compile(
  optimizer = optimizer_rmsprop(lr = 0.001),
  loss = "mse",
  metrics = c("mae")
)
```

Next, we can fit the model with a learning rate schedule. We dynamically adjust
the bandwidths of the learning rate (multiplying them by 0.9) whenever the
validation loss does not decrease for three epochs. After each reduction, we
wait two epochs (`cooldown`) before the patience counter starts again.

```{r}
iter_per_epoch <- nrow(train_data) / 32
callback_clr <- new_callback_cyclical_learning_rate(
  step_size = iter_per_epoch * 2,
  base_lr = 0.001,
  max_lr = 0.006,
  mode = "triangular",
  patience = 3,
  factor = 0.9,
  cooldown = 2,
  verbose = 0
)
model %>% fit(
  train_data, train_targets,
  validation_data = list(test_data, test_targets),
  epochs = 50, verbose = 0,
  callbacks = list(callback_clr)
)
```

We can now have a look at the learning rates:

```{r}
head(callback_clr$history)
```

```{r plot-clr}
backend <- ifelse(rlang::is_installed("ggplot2"), "ggplot2", "base")
plot_clr_history(callback_clr, granularity = "iteration", backend = backend)
```
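
The band adjustment described above (`patience`, `factor`, `cooldown`) can be pictured with a small base R sketch that runs without Keras. It is only an approximation of the idea and not the package's implementation: `scale_bands` is a hypothetical function, the validation losses are made up for illustration, and "no improvement" is assumed to mean that the loss is not strictly below the best value seen so far.

```r
# Sketch of scaling the learning rate bands on a validation loss plateau.
# `scale_bands` is hypothetical; the actual rule inside KerasMisc may differ.
scale_bands <- function(val_loss, base_lr, max_lr, patience, factor, cooldown) {
  best <- Inf          # best validation loss seen so far (assumed criterion)
  wait <- 0            # consecutive counted epochs without improvement
  cooldown_left <- 0   # epochs during which the patience counter is paused
  bands <- data.frame(
    epoch = seq_along(val_loss),
    base_lr = NA_real_,
    max_lr = NA_real_
  )
  for (epoch in seq_along(val_loss)) {
    if (val_loss[epoch] < best) {
      best <- val_loss[epoch]
      wait <- 0
    } else if (cooldown_left > 0) {
      cooldown_left <- cooldown_left - 1
    } else {
      wait <- wait + 1
      if (wait >= patience) {
        # Shrink both bands by the same constant factor.
        base_lr <- base_lr * factor
        max_lr <- max_lr * factor
        cooldown_left <- cooldown
        wait <- 0
      }
    }
    bands$base_lr[epoch] <- base_lr
    bands$max_lr[epoch] <- max_lr
  }
  bands
}

# Made-up validation losses that plateau after the second epoch.
val_loss <- c(1.0, 0.9, rep(0.95, 8))
bands <- scale_bands(
  val_loss,
  base_lr = 0.001, max_lr = 0.006,
  patience = 3, factor = 0.9, cooldown = 2
)
bands
```

With these inputs the bands are first scaled in epoch 5, after three epochs without improvement. The counter then pauses for two epochs, so the next scaling happens in epoch 10.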