---
output: github_document
---

```{r, echo = FALSE}
knitr::opts_chunk$set(
  collapse = TRUE,
  comment = "#>",
  fig.path = "fig/README-"
)
set.seed(09292017)
```

# tristan [![Project Status: WIP - Initial development is in progress, but
there has not yet been a stable, usable release suitable for the public.](http://www.repostatus.org/badges/latest/wip.svg)](http://www.repostatus.org/#wip)

This package contains my helper functions for working with models fit with
RStanARM. The package is named tristan because I'm working with Stan samples and
my name is Tristan.

I plan to incrementally update this package whenever I find myself solving the
same old problems from an RStanARM model.

## Installation

You can install tristan from github with:

```{r gh-installation, eval = FALSE}
# install.packages("devtools")
devtools::install_github("tjmahr/tristan")
```

## Overview of helpers

`augment_posterior_predict()` and `augment_posterior_linpred()` generate
posterior predictions and fitted means for new datasets using RStanARM's
`posterior_predict()` and `posterior_linpred()`. The RStanARM functions return
giant matrices of predicted values, but these functions return a long dataframe
of predicted values along with the values of the predictor variables. The name
_augment_ follows the convention of the broom package, where `augment()` refers
to augmenting a dataset with model predictions.
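
To see the difference in shape, compare the two return values. (A sketch, not
run, where `fit` and `newdata` are placeholders for a fitted RStanARM model and
a dataframe of predictors.)

```{r, eval = FALSE}
# Not run: `fit` and `newdata` are placeholders
preds_matrix <- rstanarm::posterior_predict(fit, newdata = newdata)
dim(preds_matrix)  # draws x observations

preds_long <- augment_posterior_predict(fit, newdata = newdata)
nrow(preds_long)   # draws * observations, one predicted value per row
```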

`stan_to_lm()` and `stan_to_glm()` provide a quick way to refit an RStanARM
model with its classical counterpart.

`tidy_etdi()`, `tidy_hpdi()`, `tidy_median()` and `double_etdi()` provide
helpers for interval calculation.

`draw_coef()`, `draw_fixef()`, `draw_ranef()` and `draw_var_corr()` work like
the `coef()`, `fixef()`, `ranef()`, and `VarCorr()` functions for rstanarm
mixed effects models, but they return a tidy dataframe of posterior samples
instead of point estimates.

## Example

Fit a simple linear model.

```{r model, message = FALSE, results = 'hide'}
library(tidyverse)
library(rstanarm)
library(tristan)

# Scaling makes the model run much faster
scale_v <- function(...) as.vector(scale(...))
iris$z.Sepal.Length <- scale_v(iris$Sepal.Length)
iris$z.Petal.Length <- scale_v(iris$Petal.Length)

# Just to ensure that NA values don't break the prediction function
iris[2:6, "Species"] <- NA

model <- stan_glm(
  z.Sepal.Length ~ z.Petal.Length * Species,
  data = iris,
  family = gaussian(),
  prior = normal(0, 2))
```

```{r}
print(model)
```

### Posterior fitted values (linear predictions)

Let's plot some samples of the model's linear prediction for the mean. Where a
classical model provides a single "line of best fit", a Bayesian model provides
a distribution of "lines of plausible fit". We'd like to visualize 100 of these
lines alongside the raw data.

In classical models, getting the fitted values is easily done by adding a column
of `fitted()` values to the dataframe or by calling `predict()` on some new
observations.
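
For comparison, here is a sketch of that classical workflow using base R's
`lm()` and `predict()`:

```{r}
# A classical fit of the same formula, for comparison
classical <- lm(z.Sepal.Length ~ z.Petal.Length * Species, data = iris)
head(predict(classical, newdata = iris))
```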

Because the posterior of this model contains 4000 draws, each with its own set
of fitted values, more data wrangling and reshaping is required.
`augment_posterior_linpred()` automates this task by producing a long dataframe
with one row per posterior fitted value.

Here, we tell the model that we want just 100 of those lines (i.e., 100 samples
from the posterior distribution).

```{r}
# Get the fitted means of the data for 100 samples of the posterior distribution
linear_preds <- augment_posterior_linpred(
  model = model,
  newdata = iris,
  nsamples = 100)
linear_preds
```

To plot the lines, we have to unscale the model's fitted values.

```{r}
unscale <- function(scaled, original) {
  (scaled * sd(original, na.rm = TRUE)) + mean(original, na.rm = TRUE)
}

linear_preds$.posterior_value <- unscale(
  scaled = linear_preds$.posterior_value,
  original = iris$Sepal.Length)
```

Now, we can do a spaghetti plot of linear predictions.

```{r many-lines-of-best-fit, fig.height = 4, fig.width = 6}
ggplot(iris) +
  aes(x = Petal.Length, y = Sepal.Length, color = Species) +
  geom_point() +
  geom_line(aes(y = .posterior_value, group = interaction(Species, .draw)),
            data = linear_preds, alpha = .20)
```

### Posterior predictions (simulated new data)

`augment_posterior_predict()` similarly tidies values from the
`posterior_predict()` function. `posterior_predict()` incorporates the error
terms from the model, so it can be used to predict new fake data from the model.
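
Conceptually, for this Gaussian model, each posterior prediction is a fitted
mean plus observation noise drawn with that sample's `sigma`. A rough sketch of
the idea (illustration only; it does not pair `mu_s` and `sigma_s` from the
same posterior draw, which `posterior_predict()` does internally):

```{r, eval = FALSE}
# Illustration only: a posterior prediction = fitted mean + observation noise
mu_s <- posterior_linpred(model, draws = 1)
sigma_s <- as.data.frame(model)[["sigma"]][1]
y_s <- mu_s + rnorm(length(mu_s), mean = 0, sd = sigma_s)
```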

Let's create a range of values within each species, and get posterior
predictions for those values.

```{r}
library(modelr)

# Within each species, generate a sequence of z.Petal.Length values
newdata <- iris %>%
  group_by(Species) %>%
  # Expand the range of x values a little so that the points do not bulge
  # out of the left/right sides of the uncertainty ribbon in the plot
  data_grid(z.Petal.Length = z.Petal.Length %>%
              seq_range(n = 80, expand = .10)) %>%
  ungroup()

newdata$Petal.Length <- unscale(newdata$z.Petal.Length, iris$Petal.Length)

# Get posterior predictions
posterior_preds <- augment_posterior_predict(model, newdata)
posterior_preds

posterior_preds$.posterior_value <- unscale(
  scaled = posterior_preds$.posterior_value,
  original = iris$Sepal.Length)
```

Take a second to appreciate the size of that table. It has 4000 predictions for
each of the `r nrow(newdata)` observations in `newdata`.
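
A quick sanity check on that claim (assuming the default of 4000 posterior
draws):

```{r}
nrow(posterior_preds)
nrow(newdata) * 4000
```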

Now, we might inspect whether 95% of the data falls inside the 95% interval of
posterior-predicted values (among other questions we could ask of the model).

```{r 95-percent-intervals, fig.height = 4, fig.width = 6}
ggplot(iris) +
  aes(x = Petal.Length, y = Sepal.Length, color = Species) +
  geom_point() +
  stat_summary(aes(y = .posterior_value, group = Species, color = NULL),
               data = posterior_preds, alpha = 0.4, fill = "grey60",
               geom = "ribbon",
               fun.data = median_hilow, fun.args = list(conf.int = .95))
```

### Refit RStanARM models with classical counterparts

There are some quick functions for refitting RStanARM models with their
classical counterparts. These functions basically inject the values of
`model$formula` and
`model$data` into `lm()` or `glm()`. (Seriously, see the call sections in the
two outputs below.) Therefore, don't use these functions for serious comparisons
of classical versus Bayesian models.
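
In spirit, the refitting amounts to the following sketch. (`refit_as_lm()` is a
hypothetical stand-in, not part of the package; the real functions do more
bookkeeping with the call.)

```{r, eval = FALSE}
# Hypothetical sketch of the refitting idea, not the package implementation
refit_as_lm <- function(model) {
  lm(formula = model$formula, data = model$data)
}
```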

Refit with a linear model:

```{r}
arm::display(stan_to_lm(model))
```

Refit with a generalized linear model:

```{r}
arm::display(stan_to_glm(model))
```

```{r, echo = FALSE, eval = FALSE}
arm::display(glm(
  z.Sepal.Length ~ z.Petal.Length * Species,
  data = iris,
  family = gaussian()))
```

### Calculate (classical) _R_^2^

`calculate_model_r2()` returns the unadjusted _R_^2^ for each draw of
the posterior distribution.
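
As a sketch of what such a calculation involves (not necessarily the package's
implementation), an unadjusted _R_^2^ can be computed for each row of a
draws-by-observations matrix of fitted values:

```{r, eval = FALSE}
# Sketch only: one unadjusted R2 per posterior draw of fitted values
r2_by_hand <- function(y, fits) {
  ss_total <- sum((y - mean(y)) ^ 2)
  apply(fits, 1, function(y_hat) 1 - sum((y - y_hat) ^ 2) / ss_total)
}
```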

```{r}
df_r2 <- data_frame(
  R2 = calculate_model_r2(model)
)
df_r2
```

### Interval calculation

Two more helper functions compute tidy dataframes of posterior density
intervals. `tidy_hpdi()` provides the highest-density posterior interval for
model parameters, while `tidy_etdi()` computes equal-tailed density intervals
(the typical sort of interval used as an uncertainty interval).

```{r}
tidy_hpdi(model)
tidy_etdi(model)
```

The functions also work on single vectors of numbers, for quick one-off
calculations.

```{r}
r2_intervals <- bind_rows(
  tidy_hpdi(calculate_model_r2(model)),
  tidy_etdi(calculate_model_r2(model))
)
r2_intervals
```
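
Because an equal-tailed interval is just a pair of quantiles, we can spot-check
these numbers with `quantile()` (assuming the default 90% interval width):

```{r}
quantile(calculate_model_r2(model), probs = c(.05, .95))
```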

We can compare the difference between highest-posterior density intervals and
equal-tailed intervals.

```{r r2-histogram, fig.width=6, fig.height=4}
ggplot(df_r2) +
  aes(x = R2) +
  geom_histogram() +
  geom_vline(aes(xintercept = lower, color = interval),
             data = r2_intervals, size = 1, linetype = "dashed") +
  geom_vline(aes(xintercept = upper, color = interval),
             data = r2_intervals, size = 1, linetype = "dashed") +
  labs(color = "90% interval",
       x = expression(R^2),
       y = "Num. posterior samples")
```

## Pure sugar

`double_etdi()` provides all the values needed to make a double
interval (caterpillar) plot.

```{r double-interval}
df1 <- double_etdi(calculate_model_r2(model), .95, .90)
df2 <- double_etdi(model, .95, .90)
df <- bind_rows(df1, df2)
df

ggplot(df) +
  aes(x = estimate, y = term, yend = term) +
  geom_vline(color = "white", size = 2, xintercept = 0) +
  geom_segment(aes(x = outer_lower, xend = outer_upper)) +
  geom_segment(aes(x = inner_lower, xend = inner_upper), size = 2) +
  geom_point(size = 4)
```

## Predictions for mixed effects models

This is such a finicky area that it will never work 100% of the time, but it
works well enough, especially if the new observations are subsets of the
original dataset.

```{r, results = 'hide'}
cbpp <- lme4::cbpp
m2 <- stan_glmer(
  cbind(incidence, size - incidence) ~ period + (1 | herd),
  data = cbpp, family = binomial,
  prior_covariance = decov(4, 1, 1),
  prior = normal(0, 1))
```

Add a new herd to the data.

```{r}
# Add a new herd
newdata <- cbpp %>%
  tibble::add_row(herd = "16", incidence = 0, size = 20, period = 1) %>%
  tibble::add_row(herd = "16", incidence = 0, size = 20, period = 2) %>%
  tibble::add_row(herd = "16", incidence = 0, size = 20, period = 3) %>%
  tibble::add_row(herd = "16", incidence = 0, size = 20, period = NA)

d <- augment_posterior_predict(m2, newdata) %>%
  dplyr::filter(herd == "16")
d
```

Plot the predictions.

```{r mixed-pred, fig.width=3, fig.height=2}
ggplot(d, aes(x = interaction(herd, period), y = .posterior_value / size)) +
  stat_summary(fun.data = median_hilow) +
  theme_grey(base_size = 8)
```

## Samples for mixed effects models

`coef()`, `fixef()` and `ranef()` only give the point estimates for model
effects. `draw_coef()`, `draw_fixef()` and `draw_ranef()` will sample them from
the posterior distribution.

```{r}
ranef(m2)
draw_ranef(m2, 100)

fixef(m2)
draw_fixef(m2, 100)

coef(m2)
draw_coef(m2, 100)
draw_coef(m2, 100) %>%
  select(.draw, .group_var, .group, .fixef_parameter, .total) %>%
  tidyr::spread(.fixef_parameter, .total)
```

We can visually confirm that `coef()` uses the posterior median values.

```{r test-coef}
all_coefs <- draw_coef(m2)

medians <- coef(m2) %>%
  lapply(tibble::rownames_to_column, ".group") %>%
  bind_rows(.id = ".group_var") %>%
  tidyr::gather(.term, .posterior_median, -.group_var, -.group)

ggplot(all_coefs) +
  aes(x = interaction(.group_var, .group), y = .total) +
  stat_summary(fun.data = median_hilow) +
  facet_wrap(".term") +
  geom_point(aes(y = .posterior_median), data = medians, color = "red")
```

`draw_var_corr()` provides an analogue to `VarCorr()`. Let's compare the prior
distribution of the correlation terms under different settings of the `decov()`
prior.

```{r, results = 'hide'}
sleep <- lme4::sleepstudy

m0 <- stan_glmer(
  Reaction ~ Days + (Days | Subject),
  data = sleep,
  prior_covariance = decov(.25, 1, 1),
  prior = normal(0, 1),
  prior_PD = TRUE,
  chains = 1)

m1 <- stan_glmer(
  Reaction ~ Days + (Days | Subject),
  data = sleep,
  prior_covariance = decov(1, 1, 1),
  prior = normal(0, 1),
  prior_PD = TRUE,
  chains = 1)

m2 <- stan_glmer(
  Reaction ~ Days + (Days | Subject),
  data = sleep,
  prior_covariance = decov(2, 1, 1),
  prior = normal(0, 1),
  prior_PD = TRUE,
  chains = 1)

m3 <- stan_glmer(
  Reaction ~ Days + (Days | Subject),
  data = sleep,
  prior_covariance = decov(4, 1, 1),
  prior = normal(0, 1),
  prior_PD = TRUE,
  chains = 1)

m4 <- stan_glmer(
  Reaction ~ Days + (Days | Subject),
  data = sleep,
  prior_covariance = decov(8, 1, 1),
  prior = normal(0, 1),
  prior_PD = TRUE,
  chains = 1)
```

`VarCorr()` returns the mean variance-covariance/standard-deviation-correlation
values.

```{r}
VarCorr(m3)

as.data.frame(VarCorr(m3))
```

`draw_var_corr()`, in contrast, computes the sd-cor matrix for each posterior
sample.

```{r var-corr-draws}
vcs0 <- draw_var_corr(m0)
vcs1 <- draw_var_corr(m1)
vcs2 <- draw_var_corr(m2)
vcs3 <- draw_var_corr(m3)
vcs4 <- draw_var_corr(m4)

vcs3
```

Filter to just the correlations.

```{r var-corr}
cors <- bind_rows(
  vcs0 %>% mutate(model = "decov(.25, 1, 1)"),
  vcs1 %>% mutate(model = "decov(1, 1, 1)"),
  vcs2 %>% mutate(model = "decov(2, 1, 1)"),
  vcs3 %>% mutate(model = "decov(4, 1, 1)"),
  vcs4 %>% mutate(model = "decov(8, 1, 1)")) %>%
  filter(!is.na(var2))

ggplot(cors) +
  aes(y = sdcor, x = .parameter, color = model) +
  stat_summary(fun.data = median_hilow, geom = "linerange",
               fun.args = list(conf.int = .9),
               position = position_dodge(-.8)) +
  stat_summary(fun.data = median_hilow, size = 1.5, geom = "linerange",
               fun.args = list(conf.int = .7), show.legend = FALSE,
               position = position_dodge(-.8)) +
  stat_summary(fun.data = median_hilow, size = 2.5, geom = "linerange",
               fun.args = list(conf.int = .5), show.legend = FALSE,
               position = position_dodge(-.8)) +
  coord_flip() +
  facet_wrap(".parameter") +
  theme(axis.text.y = element_blank(),
        axis.ticks.y = element_blank(),
        panel.grid.major.y = element_blank()) +
  labs(x = NULL, y = "correlation", color = "prior") +
  ggtitle("Degrees of regularization with LKJ prior")

ggplot(cors) +
  aes(x = sdcor, color = model) +
  stat_density(geom = "line", adjust = .2) +
  facet_wrap(".parameter")
```