Ecosyste.ms: Awesome
An open API service indexing awesome lists of open source software.
https://github.com/alexhallam/media_mix_sim
Media Mix Model with simulated data and stan
media media-mix-modeling roi simulation
- Host: GitHub
- URL: https://github.com/alexhallam/media_mix_sim
- Owner: alexhallam
- Created: 2018-05-22T18:20:32.000Z (over 6 years ago)
- Default Branch: master
- Last Pushed: 2020-02-16T05:28:57.000Z (almost 5 years ago)
- Last Synced: 2024-12-06T17:30:32.242Z (2 months ago)
- Topics: media, media-mix-modeling, roi, simulation
- Language: Stan
- Size: 2.7 MB
- Stars: 6
- Watchers: 4
- Forks: 5
- Open Issues: 1
Metadata Files:
- Readme: README.Rmd
README
---
output: github_document
---

```{r, include = FALSE}
knitr::opts_chunk$set(
collapse = TRUE,
comment = "#>"
)
```

```{r media data}
library(tidyverse)
library(corrr)
library(hrbrthemes)
# -------------------------------------
# Generate Media Data
# Equation: y = a*sin(b*t) + c.norm*amp (Gaussian noise)
# -------------------------------------
set.seed(1)
n <- 52 * 2 # number of data points
t <- seq(0, 4*pi, length.out = n)
b <- 8 # roughly the number of cycles (peaks) per year
c.norm1 <- rnorm(n,0,0.5)
c.norm2 <- rnorm(n,0, 0.75)
c.norm3 <- rnorm(n,0, 0.75)
amp <- 2
# generate data and calculate "y"
media_tv <- 1*sin(b*t)+c.norm1*amp # Gaussian/normal error
media_radio <- 1*sin(b*t)+c.norm2*amp # Gaussian/normal error
media_online <- 1*sin(b*t)+c.norm3*amp # Gaussian/normal error
week <- 1:104

sim_df <- as.data.frame(list(week = week,
                             media_tv = media_tv,
                             media_radio = media_radio,
                             media_online = media_online)) %>%
  mutate_at(vars(media_tv, media_radio, media_online), percent_rank)

sim_df %>%
ggplot(aes(x = week)) +
geom_line(aes(y = media_tv, color = "tv")) +
geom_line(aes(y = media_radio, color = "radio")) +
geom_line(aes(y = media_online, color = "online")) +
theme_ipsum()
```

```{r price data}
# -----------------------------------------------------------------------------
# Generate Menu Price Data
# Equation: y = ARMA(2,2) process simulated with arima.sim()
# Interpretation: Price has increased by 3 dollars in the last year
# -----------------------------------------------------------------------------
price <- arima.sim(n = n, list(ar = c(0.8897, -0.4858), ma = c(0.279, 0.2488)), sd = sqrt(0.001))
mean(price); sd(price)
hist(price)
plot(price)
#price <- as.numeric(scale(price))
sim_df <- add_column(sim_df, price)

sim_df %>%
gather(key = vartype, value = value, -week) %>%
ggplot(aes(x = week, y = value)) +
geom_col()+
facet_wrap(~vartype) +
hrbrthemes::theme_ipsum()
```

```{r}
# --------------------------------------------------
# Scale media to [0, 1] as per paper
# media_i = (x_i - min(x)) / (max(x) - min(x))
# --------------------------------------------------
my_normalizer <- function(x) (x - min(x)) / (max(x) - min(x))
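
# Quick sanity check of the scaler (illustrative only): the minimum maps to 0,
# the midpoint to 0.5, and the maximum to 1
my_normalizer(c(2, 4, 6)) # 0.0 0.5 1.0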

sim_df <- sim_df %>%
  mutate_at(vars(-week, -price), my_normalizer)

sim_df %>% write_csv("sim_df.csv")
sim_df %>%
gather(key = vartype, value = value, -week) %>%
ggplot(aes(x = week, y = value)) +
geom_col()+
facet_wrap(~vartype) +
hrbrthemes::theme_ipsum()
```

```{r adstock}
#-----------------------------------------------------------------------------
# Generate adstock as described in the Google Paper:
# Bayesian Methods for Media Mix Modeling with Carryover and Shape Effects
#-----------------------------------------------------------------------------
library(tidyquant)

# fake data
date <- seq(from = as.Date("2018-01-01"), length.out = 104, by = "1 week")
x <- c(rep(100,5), rep(0,99)) %>%
as_tibble() %>%
add_column(date) %>%
  rename(x = value)

# -----------------------------------------------------------------------------
# Name: Carryover Effect
#
# Description: Two functions are provided for modeling the decay of ad
# effect.
#
# Geometric: This function assumes that week 1 is the most impactful
#            week of the promo. Subsequent weeks show a slow decline
#            as defined by the rate. A larger rate gives a slower decline.
#
# Delayed:   This function assumes that a week after week 1 is the most
#            impactful week. Each week gets a weight proportional to a
#            normal-shaped curve centered on the week of peak impact,
#            which is defined by theta.
# -----------------------------------------------------------------------------
geom_decay <- function(rate,l,...) sum((rate^l) *...) / sum(rate^l)
delayed_decay <- function(rate,l,theta,...){
sum((rate^(l-theta)^2) *...) / sum(rate^(l-theta)^2)
}

# Examples of calculating adstock from both functions.
# Since the values are calculated on a rolling window,
# it is necessary to use something like tq_mutate()
# to get values for a given time series.

L <- 13
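
# Illustrative direct calls (not part of the original analysis). x_window is a
# 13-week window with a single 100-unit pulse in its first element; l = 0
# corresponds to the first element of the window. geom_decay() weights the
# values by rate^l and normalizes; delayed_decay() centers the weights on
# lag theta instead of lag 0.
x_window <- c(100, rep(0, L - 1))
geom_decay(rate = 0.8, l = seq(0, L - 1), x_window)               # ~21.2
delayed_decay(rate = 0.8, l = seq(0, L - 1), theta = 1, x_window) # ~25.2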
x %>%
tq_mutate(select = x, mutate_fun = rollapply, width = L, align = "right",
FUN = geom_decay,
#function args
rate = 0.8,
l = seq(from = 0 , to = L-1),
#ts_mutate
col_rename = "adstock_geometric_decay"
)
x %>%
tq_mutate(select = x, mutate_fun = rollapply, width = L, align = "right",
FUN = delayed_decay,
#function args
rate = 0.8, # rate of 0.4 to 0.8 is sensible
theta = 1, # theta should be about 1 to 3
l = seq(from = 0 , to = L-1),
#ts_mutate
col_rename = "adstock_delayed_decay"
)
```

```{r bhill}
# -----------------------------------------------------------------------------
# Name: Shape Effect
#
# Description: It is not enough to model the decay and the lag of an ad.
#              The shape of its saturation is also an important function
#              that deserves attention.
#
# Hill Function: Marketing mix modelers often choose between S-curves and
#                C-curves when modeling media impact on sales.
#                Pharmacology uses the Hill function to model receptors.
#                It provides a flexible functional form that can take the
#                shape of either an S-curve or a C-curve, which makes it a
#                convenient way to parameterize the shape effect.
# K: Half Saturation
# S: Slope
# B: Beta
#
# Problem: The slope parameter "S" may have to be fixed at 1 (S = 1)
#          because of identifiability issues.
# -----------------------------------------------------------------------------

# Define Function
BHill <- function(B, K, S, ...) B - ((K^S * B) / (...^S + K^S))
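
# Sanity check (illustrative only): at the half-saturation point x = K the
# function returns B / 2 regardless of the slope S
BHill(B = 0.8, K = 0.2, S = 1, 0.2) # 0.4
BHill(B = 0.8, K = 0.2, S = 3, 0.2) # 0.4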

# set up example data
x <- seq(0, 1, length.out = 100) # media must be transformed to the [0, 1]
                                 # scale for ease of use

params <- tribble(
~K, ~S, ~B, ~type,
0.5, 1, 0.3, "simple_c",
0.5, 2, 0.3, "simple_s",
0.5, 0.25, 0.3, "sharp_c"
)

bhill_df <- crossing(x, params) %>%
  mutate(y = BHill(B, K, S, x))

bhill_df %>%
ggplot(aes(x = x, y = y, color = type)) +
geom_line() +
labs(title = "Flexible Shape Function") +
hrbrthemes::theme_ipsum()
```

```{r}
# =============================================================================
# Simulation
# Description: With simulated media data as "media variables" and
#              simulated price as a "control variable", I applied the necessary
#              transformations (adstock & shape) to the input variables with
#              various parameters to test the model's ability to
#              recover the parameters I set.
#
# Equation: Weekly sales have the following form:
#
# sales_wk = tau + BHill_tv_wk + BHill_online_wk + BHill_radio_wk
#            + gamma*price_wk + e_wk
#
#              (a small worked example of this equation follows the
#              parameter tables below)
# =============================================================================

#------------------------------------------------------------
#
# Media Parameters
# ----------------
# Parameter | Media_tv | Media_radio | Media_online
# rate 0.6 0.8 0.8
# theta 5 3 4
# K 0.2 0.2 0.2
# S 1 2 2
# B 0.8 0.6 0.3
#
# Other variables
# ----------------
# Parameter | Value
# L 13
# tau 4
# gamma 0.05
# e normal(0,0.05^2)
#------------------------------------------------------------
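
# Worked example of the weekly sales equation (illustrative numbers, not taken
# from the simulated data): with tau = 4 and gamma = 0.5 (the values used in
# the mutate() call below), Hill-transformed media of 0.41 (tv), 0.30 (radio),
# 0.15 (online), price of -0.1, and the noise term ignored:
4 + 0.41 + 0.30 + 0.15 + 0.5 * (-0.1) # expected sales of 4.81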
fat_data <- sim_df %>%
add_column(date = seq(as.Date("2017-01-01"), length.out = n, by = "week")) %>%
# adstocks
tq_mutate(select = media_tv, mutate_fun = rollapply, width = L, align = "right",
FUN = delayed_decay,rate = 0.6, theta = 1,
l = seq(from = 0 , to = L-1),col_rename = "adstk_tv"
)%>%
tq_mutate(select = media_radio, mutate_fun = rollapply, width = L, align = "right",
FUN = delayed_decay,rate = 0.8, theta = 1,
l = seq(from = 0 , to = L-1),col_rename = "adstk_radio"
)%>%
tq_mutate(select = media_online, mutate_fun = rollapply, width = L, align = "right",
FUN = delayed_decay,rate = 0.8, theta = 1,
l = seq(from = 0 , to = L-1),col_rename = "adstk_online"
)%>%
# Shape
mutate(m_tv = BHill(K = 0.2, S = 1, B = 0.8,adstk_tv)) %>%
mutate(m_rd = BHill(K = 0.2, S = 1, B = 0.6,adstk_radio)) %>%
mutate(m_online = BHill(K = 0.2, S = 1, B = 0.3,adstk_online)) %>%
#and errors
mutate(e = rnorm(n = n(), mean = 0, sd = 0.25^2)) %>%
  mutate(sales = 4 + m_tv + m_rd + m_online + .5 * price + e)

clean_data <- fat_data %>%
select(date, sales,m_tv,m_rd,m_online,price,e) %>%
na.omit()
clean_data
```

```{r}
library(hrbrthemes)
# some plots of the data to see if it matches the paper
clean_data %>%
ggplot(aes(date, sales)) + geom_line() + theme_ipsum()
```

```{r}
clean_data %>%
select(price, m_tv, m_rd, m_online) %>%
correlate()
```

```{r}
clean_data %>%
select(sales, m_tv, m_rd, m_online, e, price) %>%
summarise_all(var) %>%
transmute(var_tv = m_tv / sales,
var_rd = m_rd / sales,
var_online = m_online / sales,
var_noise = e / sales,
price = price / sales)
```
```{r}
clean_data %>% write_csv("clean_data.csv")
```

```{r}
media_data <- clean_data %>% select(contains("m_"))
# data Prep
N <- nrow(clean_data)
Y <- clean_data$sales
max_lag <- 13
num_media <- 3
lag_vec <- seq(0, max_lag - 1)
X_media <- array(data = media_data, dim = c(num_media))
num_ctrl <- 1
X_ctrl <- clean_data$price

stan_data <- list(N = N, Y = Y, max_lag = max_lag, num_media = num_media,
                  lag_vec = lag_vec, X_media = X_media,
                  num_ctrl = num_ctrl, X_ctrl = X_ctrl)

stan_data %>% str()
```

```{r}
library(rstan)
clean_data <- read_csv("clean_data.csv")
media_data <- clean_data %>% select(contains("m_"))
long_media_array <- c(clean_data$m_tv,clean_data$m_rd,clean_data$m_online)
# data Prep
N <- nrow(clean_data)
Y <- clean_data$sales
max_lag <- 13
num_media <- 3
lag_vec <- seq(0, max_lag - 1)
X_media <- array(data = media_data, dim = c(3,13))
X_media <- array(data = long_media_array, dim = c(92,3,13))
num_ctrl <- 1
X_ctrl <- clean_data %>% select(price) %>% as.vector()

stan_data <- list(N = N, Y = Y, max_lag = max_lag, num_media = num_media,
                  lag_vec = lag_vec, X_media = X_media,
                  num_ctrl = num_ctrl, X_ctrl = X_ctrl)

m.stan <- stan(file = "model.stan", data = stan_data, iter = 3000, chains = 1,
               control = list(max_treedepth = 15))
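
# Optional convergence checks (a minimal sketch; the parameter names inside
# model.stan are not assumed here, so only model-agnostic diagnostics are used)
rstan::check_hmc_diagnostics(m.stan)
summary(m.stan)$summary[, "Rhat"] %>% summary()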
#summary(m.stan)
```

```{r}
m.stan
```

```{r}
rstan::get_posterior_mean(m.stan)
list_of_draws <- extract(m.stan)

predicted_sales <- summary(m.stan, pars = "mu", probs = NULL)$summary %>%
as_tibble() %>%
select(mean) %>%
  rename(pred_sales = mean)

pred_and_sales <- predicted_sales %>%
add_column(sales = clean_data$sales) %>%
  mutate(index = row_number())

pred_and_sales %>%
ggplot(aes(x = index)) +
geom_line(aes(y = sales), color = "black") +
geom_line(aes(y = pred_sales), color = "red")
```

```{r}
# look at the functions learned from the model
x <- seq(0, 1, length.out = 100)
tv_pred <- BHill(B = 1.20, K = 0.50, S = 2.23,x)
rd_pred <- BHill(B = 0.95, K = 0.50, S = 2.45,x)
online_pred <- BHill(B = 0.90, K = 0.50, S = 1.59, x)

m_tv <- BHill(K = 0.2, S = 1, B = 0.8, x)
m_rd <- BHill(K = 0.2, S = 1, B = 0.6,x)
m_online <- BHill(K = 0.2, S = 1, B = 0.3, x)

as_tibble(list(tv_actual = m_tv, tv_pred = tv_pred)) %>%
mutate(index = row_number()) %>%
ggplot(aes(x = index)) +
geom_line(aes(y = tv_actual, color = "actual")) +
  geom_line(aes(y = tv_pred, color = "pred"))

as_tibble(list(rd_actual = m_rd, rd_pred = rd_pred)) %>%
mutate(index = row_number()) %>%
ggplot(aes(x = index)) +
geom_line(aes(y = rd_actual, color = "actual")) +
  geom_line(aes(y = rd_pred, color = "pred"))

as_tibble(list(online_actual = m_online, online_pred = online_pred)) %>%
mutate(index = row_number()) %>%
ggplot(aes(x = index)) +
geom_line(aes(y = online_actual, color = "actual")) +
geom_line(aes(y = online_pred, color = "pred"))
```

```{r}
rgamma(n = 100, shape = 2, scale = .25) %>% hist()
```