https://github.com/tjmahr/tristan
Some helper functions for dealing with RStanARM models
- Host: GitHub
- URL: https://github.com/tjmahr/tristan
- Owner: tjmahr
- License: gpl-3.0
- Created: 2017-05-05T15:49:25.000Z (over 8 years ago)
- Default Branch: master
- Last Pushed: 2017-11-02T20:10:56.000Z (about 8 years ago)
- Last Synced: 2025-10-09T12:35:54.973Z (about 1 month ago)
- Topics: bayesian, rstanarm, rstats, stan
- Language: R
- Homepage:
- Size: 441 KB
- Stars: 5
- Watchers: 1
- Forks: 0
- Open Issues: 2
Metadata Files:
- Readme: README.Rmd
- License: LICENSE
README
---
output: github_document
---
```{r, echo = FALSE}
knitr::opts_chunk$set(
collapse = TRUE,
comment = "#>",
fig.path = "fig/README-"
)
set.seed(09292017)
```
# tristan [![Project Status: WIP](http://www.repostatus.org/badges/latest/wip.svg)](http://www.repostatus.org/#wip)
This package contains my helper functions for working with models fit with
RStanARM. The package is named tristan because I'm working with Stan samples and
my name is Tristan.
I plan to incrementally update this package whenever I find myself solving the
same old problems from an RStanARM model.
## Installation
You can install tristan from GitHub with:
```{r gh-installation, eval = FALSE}
# install.packages("devtools")
devtools::install_github("tjmahr/tristan")
```
## Overview of helpers
`augment_posterior_predict()` and `augment_posterior_linpred()` generate
posterior predictions and fitted means for new datasets using RStanARM's
`posterior_predict()` and `posterior_linpred()`. The RStanARM functions return
giant matrices of predicted values, but these functions return a long dataframe
of predicted values along with the values of the predictor variables. The name
_augment_ follows the convention of the broom package, where `augment()` refers
to augmenting a dataset with model predictions.
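To make the shapes concrete, here is a minimal sketch (`fit` and `new_obs` are
hypothetical placeholders for a fitted RStanARM model and a dataframe of
predictors):

```{r, eval = FALSE}
# posterior_predict() returns a draws-by-observations matrix
preds <- rstanarm::posterior_predict(fit, newdata = new_obs)
dim(preds)  # num. posterior draws x nrow(new_obs)

# augment_posterior_predict() instead returns one row per draw per
# observation, with the predictor columns from new_obs carried along
augment_posterior_predict(fit, new_obs)
```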
`stan_to_lm()` and `stan_to_glm()` provide a quick way to refit an RStanARM
model with its classical counterpart.
`tidy_etdi()`, `tidy_hpdi()`, `tidy_median()` and `double_etdi()` provide
helpers for interval calculation.
`draw_coef()`, `draw_fixef()`, `draw_ranef()` and `draw_var_corr()` work like
the `coef()`, `fixef()`, `ranef()`, and `VarCorr()` functions for rstanarm
mixed effects models, but they return a tidy dataframe of posterior samples
instead of point estimates.
## Example
Fit a simple linear model.
```{r model, message = FALSE, results = 'hide'}
library(tidyverse)
library(rstanarm)
library(tristan)
# Scaling makes the model run much faster
scale_v <- function(...) as.vector(scale(...))
iris$z.Sepal.Length <- scale_v(iris$Sepal.Length)
iris$z.Petal.Length <- scale_v(iris$Petal.Length)
# Just to ensure that NA values don't break the prediction function
iris[2:6, "Species"] <- NA
model <- stan_glm(
z.Sepal.Length ~ z.Petal.Length * Species,
data = iris,
family = gaussian(),
prior = normal(0, 2))
```
```{r}
print(model)
```
### Posterior fitted values (linear predictions)
Let's plot some samples of the model's linear prediction for the mean. Where a
classical model provides a single "line of best fit", a Bayesian model provides
a distribution of "lines of plausible fit". We'd like to visualize 100 of these
lines alongside the raw data.
In classical models, getting the fitted values is easily done by adding a column
of `fitted()` values to the dataframe or by calling `predict()` on some new
observations. Because the posterior of this model contains 4000 draws, each with
its own set of fitted or predicted values, more data wrangling and reshaping is
required.
`augment_posterior_linpred()` automates this task by producing a long dataframe
with one row per posterior fitted value.
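For comparison, the manual version of that wrangling might look something like
this sketch (it glosses over details, like the missing values in the new data,
that the helper handles):

```{r, eval = FALSE}
# posterior_linpred() gives a draws-by-observations matrix
fits <- posterior_linpred(model, newdata = iris)

# Reshape by hand: label each draw, then gather into a long dataframe
fits %>%
  as.data.frame() %>%
  mutate(.draw = seq_len(n())) %>%
  tidyr::gather(.observation, .posterior_value, -.draw)
```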
Here, we request just 100 of those lines (i.e., 100 samples from the posterior
distribution).
```{r}
# Get the fitted means of the data for 100 samples of the posterior distribution
linear_preds <- augment_posterior_linpred(
model = model,
newdata = iris,
nsamples = 100)
linear_preds
```
To plot the lines on the original scale of the data, we have to unscale the
model's fitted values.
```{r}
unscale <- function(scaled, original) {
(scaled * sd(original, na.rm = TRUE)) + mean(original, na.rm = TRUE)
}
linear_preds$.posterior_value <- unscale(
scaled = linear_preds$.posterior_value,
original = iris$Sepal.Length)
```
Now, we can do a spaghetti plot of linear predictions.
```{r many-lines-of-best-fit, fig.height = 4, fig.width = 6}
ggplot(iris) +
aes(x = Petal.Length, y = Sepal.Length, color = Species) +
geom_point() +
geom_line(aes(y = .posterior_value, group = interaction(Species, .draw)),
data = linear_preds, alpha = .20)
```
### Posterior predictions (simulated new data)
`augment_posterior_predict()` similarly tidies values from the
`posterior_predict()` function. `posterior_predict()` incorporates the error
terms from the model, so it can be used to simulate fake new data from the
model.
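One way to see the difference (a sketch, using the model from above): simulated
observations include residual noise, so they vary more than the corresponding
fitted means.

```{r, eval = FALSE}
first_obs_sims <- posterior_predict(model)[, 1]
first_obs_means <- posterior_linpred(model)[, 1]
var(first_obs_sims) > var(first_obs_means)
```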
Let's create a range of values within each species, and get posterior
predictions for those values.
```{r}
library(modelr)
# Within each species, generate a sequence of z.Petal.Length values
newdata <- iris %>%
group_by(Species) %>%
# Expand the x range a little bit so that the points do not bulge out of the
# left/right sides of the uncertainty ribbon in the plot
data_grid(z.Petal.Length = z.Petal.Length %>%
seq_range(n = 80, expand = .10)) %>%
ungroup()
newdata$Petal.Length <- unscale(newdata$z.Petal.Length, iris$Petal.Length)
# Get posterior predictions
posterior_preds <- augment_posterior_predict(model, newdata)
posterior_preds
posterior_preds$.posterior_value <- unscale(
scaled = posterior_preds$.posterior_value,
original = iris$Sepal.Length)
```
Take a second to appreciate the size of that table. It has 4000 predictions for
each of the `r nrow(newdata)` observations in `newdata`.
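A quick sanity check on those dimensions (a sketch; the 4000 comes from
RStanARM's default of 4 chains with 1000 post-warmup draws each):

```{r, eval = FALSE}
# One row per posterior draw per observation in newdata
nrow(posterior_preds) == 4000 * nrow(newdata)
```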
Now, we might inspect whether 95% of the data falls inside the 95% interval of
posterior-predicted values (among other questions we could ask of the model).
```{r 95-percent-intervals, fig.height = 4, fig.width = 6}
ggplot(iris) +
aes(x = Petal.Length, y = Sepal.Length, color = Species) +
geom_point() +
stat_summary(aes(y = .posterior_value, group = Species, color = NULL),
data = posterior_preds, alpha = 0.4, fill = "grey60",
geom = "ribbon",
fun.data = median_hilow, fun.args = list(conf.int = .95))
```
### Refit RStanARM models with classical counterparts
There are some quick functions for refitting RStanARM models using classical
versions. These functions basically inject the values of `model$formula` and
`model$data` into `lm()` or `glm()`. (Seriously, see the call sections in the
two outputs below.) Therefore, don't use these functions for serious comparisons
of classical versus Bayesian models.
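Conceptually, the refit amounts to something like this sketch:

```{r, eval = FALSE}
# Roughly what stan_to_lm() does: reuse the formula and data from the
# Bayesian model in a classical fit
lm(model$formula, data = model$data)
```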
Refit with a linear model:
```{r}
arm::display(stan_to_lm(model))
```
Refit with a generalized linear model:
```{r}
arm::display(stan_to_glm(model))
```
```{r, echo = FALSE, eval = FALSE}
arm::display(glm( z.Sepal.Length ~ z.Petal.Length * Species,
data = iris,
family = gaussian()))
```
### Calculate (classical) _R_²
`calculate_model_r2()` returns the unadjusted _R_² for each draw of
the posterior distribution.
```{r}
df_r2 <- tibble(
  R2 = calculate_model_r2(model)
)
df_r2
```
### Interval calculation
Two more helper functions compute tidy data-frames of posterior density
intervals. `tidy_hpdi()` provides the highest-density posterior interval for
model parameters, while `tidy_etdi()` computes the equal-tailed density
intervals (the typical sort of intervals used for uncertain intervals.)
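For intuition, an equal-tailed interval is just a pair of symmetric tail
quantiles. A 90% ETDI for the _R_² draws, sketched by hand:

```{r, eval = FALSE}
# Cut 5% of the posterior probability from each tail
quantile(calculate_model_r2(model), probs = c(0.05, 0.95))
```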
```{r}
tidy_hpdi(model)
tidy_etdi(model)
```
The functions also work on single vectors of numbers, for quick one-off
calculations.
```{r}
r2_intervals <- bind_rows(
tidy_hpdi(calculate_model_r2(model)),
tidy_etdi(calculate_model_r2(model))
)
r2_intervals
```
We can compare the difference between highest-posterior density intervals and
equal-tailed intervals. Because the _R_² distribution is skewed, the two do not
coincide: the HPDI shifts toward the mode, while the ETDI cuts an equal amount
of probability from each tail.
```{r r2-histogram, fig.width=6, fig.height=4}
ggplot(df_r2) +
aes(x = R2) +
geom_histogram() +
geom_vline(aes(xintercept = lower, color = interval),
data = r2_intervals, size = 1, linetype = "dashed") +
geom_vline(aes(xintercept = upper, color = interval),
data = r2_intervals, size = 1, linetype = "dashed") +
labs(color = "90% interval",
x = expression(R^2),
y = "Num. posterior samples")
```
## Pure sugar
`double_etdi()` provides all the values needed to make a double-interval
(caterpillar) plot.
```{r double-interval}
df1 <- double_etdi(calculate_model_r2(model), .95, .90)
df2 <- double_etdi(model, .95, .90)
df <- bind_rows(df1, df2)
df
ggplot(df) +
aes(x = estimate, y = term, yend = term) +
geom_vline(color = "white", size = 2, xintercept = 0) +
geom_segment(aes(x = outer_lower, xend = outer_upper)) +
geom_segment(aes(x = inner_lower, xend = inner_upper), size = 2) +
geom_point(size = 4)
```
## Predictions for mixed effects models
This is such a finicky area that it will never work 100% of the time, but it
works well enough, especially if the new observations are subsets of the
original dataset.
```{r, results = 'hide'}
cbpp <- lme4::cbpp
m2 <- stan_glmer(
cbind(incidence, size - incidence) ~ period + (1 | herd),
data = cbpp, family = binomial,
prior_covariance = decov(4, 1, 1),
prior = normal(0, 1))
```
Add a new herd to the data.
```{r}
# Add a new herd
newdata <- cbpp %>%
tibble::add_row(herd = "16", incidence = 0, size = 20, period = 1) %>%
tibble::add_row(herd = "16", incidence = 0, size = 20, period = 2) %>%
tibble::add_row(herd = "16", incidence = 0, size = 20, period = 3) %>%
tibble::add_row(herd = "16", incidence = 0, size = 20, period = NA)
d <- augment_posterior_predict(m2, newdata) %>%
dplyr::filter(herd == "16")
d
```
Plot the predictions.
```{r mixed-pred, fig.width=3, fig.height=2}
ggplot(d, aes(x = interaction(herd, period), y = .posterior_value / size)) +
stat_summary(fun.data = median_hilow) +
theme_grey(base_size = 8)
```
## Samples for mixed effects models
`coef()`, `fixef()` and `ranef()` only give the point estimates for model
effects. `draw_coef()`, `draw_fixef()` and `draw_ranef()` instead sample them
from the posterior distribution.
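Under the hood, RStanARM already exposes the raw draws; the `draw_` helpers add
tidying and bookkeeping on top. A sketch of pulling fixed-effect draws manually:

```{r, eval = FALSE}
# as.data.frame() on a stanreg model gives one row per posterior draw
draws <- as.data.frame(m2)
draws[names(fixef(m2))] %>% head()
```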
```{r}
ranef(m2)
draw_ranef(m2, 100)
fixef(m2)
draw_fixef(m2, 100)
coef(m2)
draw_coef(m2, 100)
draw_coef(m2, 100) %>%
select(.draw, .group_var, .group, .fixef_parameter, .total) %>%
tidyr::spread(.fixef_parameter, .total)
```
We can visually confirm that `coef()` uses the posterior median values.
```{r test-coef}
all_coefs <- draw_coef(m2)
medians <- coef(m2) %>%
lapply(tibble::rownames_to_column, ".group") %>%
bind_rows(.id = ".group_var") %>%
tidyr::gather(.term, .posterior_median, -.group_var, -.group)
ggplot(all_coefs) +
aes(x = interaction(.group_var, .group), y = .total) +
stat_summary(fun.data = median_hilow) +
facet_wrap(".term") +
geom_point(aes(y = .posterior_median), data = medians, color = "red")
```
`draw_var_corr()` provides an analogue to `VarCorr()`. Let's compare the prior
distribution of the correlation terms under different settings of the `decov()`
prior.
```{r, results = 'hide'}
sleep <- lme4::sleepstudy
m0 <- stan_glmer(
Reaction ~ Days + (Days | Subject),
data = sleep,
prior_covariance = decov(.25, 1, 1),
prior = normal(0, 1),
prior_PD = TRUE,
chains = 1)
m1 <- stan_glmer(
Reaction ~ Days + (Days | Subject),
data = sleep,
prior_covariance = decov(1, 1, 1),
prior = normal(0, 1),
prior_PD = TRUE,
chains = 1)
m2 <- stan_glmer(
Reaction ~ Days + (Days | Subject),
data = sleep,
prior_covariance = decov(2, 1, 1),
prior = normal(0, 1),
prior_PD = TRUE,
chains = 1)
m3 <- stan_glmer(
Reaction ~ Days + (Days | Subject),
data = sleep,
prior_covariance = decov(4, 1, 1),
prior = normal(0, 1),
prior_PD = TRUE,
chains = 1)
m4 <- stan_glmer(
Reaction ~ Days + (Days | Subject),
data = sleep,
prior_covariance = decov(8, 1, 1),
prior = normal(0, 1),
prior_PD = TRUE,
chains = 1)
```
`VarCorr()` returns the mean variance-covariance/standard-deviation-correlation
values.
```{r}
VarCorr(m3)
as.data.frame(VarCorr(m3))
```
`draw_var_corr()`, by contrast, computes the sd-cor matrix for each posterior
sample.
```{r var-corr-draws}
vcs0 <- draw_var_corr(m0)
vcs1 <- draw_var_corr(m1)
vcs2 <- draw_var_corr(m2)
vcs3 <- draw_var_corr(m3)
vcs4 <- draw_var_corr(m4)
vcs3
```
Filter to just the correlations.
```{r var-corr}
cors <- bind_rows(
vcs0 %>% mutate(model = "decov(.25, 1, 1)"),
vcs1 %>% mutate(model = "decov(1, 1, 1)"),
vcs2 %>% mutate(model = "decov(2, 1, 1)"),
vcs3 %>% mutate(model = "decov(4, 1, 1)"),
vcs4 %>% mutate(model = "decov(8, 1, 1)")) %>%
filter(!is.na(var2))
ggplot(cors) +
aes(y = sdcor, x = .parameter, color = model) +
stat_summary(fun.data = median_hilow, geom = "linerange",
fun.args = list(conf.int = .9),
position = position_dodge(-.8)) +
stat_summary(fun.data = median_hilow, size = 1.5, geom = "linerange",
fun.args = list(conf.int = .7), show.legend = FALSE,
position = position_dodge(-.8)) +
stat_summary(fun.data = median_hilow, size = 2.5, geom = "linerange",
fun.args = list(conf.int = .5), show.legend = FALSE,
position = position_dodge(-.8)) +
coord_flip() +
facet_wrap(".parameter") +
theme(axis.text.y = element_blank(),
axis.ticks.y = element_blank(),
panel.grid.major.y = element_blank()) +
labs(x = NULL, y = "correlation", color = "prior") +
ggtitle("Degrees of regularization with LKJ prior")
ggplot(cors) +
aes(x = sdcor, color = model) +
stat_density(geom = "line", adjust = .2) +
facet_wrap(".parameter")
```